
Update for final review. p2ch12 code might get further tweaks.

Eli Stevens committed 6 years ago (commit f5d199ca86)
44 changed files with 5677 additions and 404 deletions
  1. BIN
      data/p1ch3/ourpoints.hdf5
  2. BIN
      data/p1ch3/ourpoints.t
  3. 1 0
      data/p1ch6/cifar-10-batches-py/readme.html
  4. BIN
      data/part2/models/alternate_cls_2019-06-16_14.18.58_ch12-other-model.best.state
  5. BIN
      data/part2/models/cls_2019-06-23_14.57.45_redo.best.state
  6. BIN
      data/part2/models/seg_2019-06-22_22.55.11_ch12-no-aug-bn-all-metrics-more-samples.750000.state
  7. 59 59
      p1ch3/1_tensors.ipynb
  8. 13 13
      p1ch4/1_tabular_wine.ipynb
  9. 29 29
      p1ch5/1_parameter_estimation.ipynb
  10. 51 51
      p1ch6/1_neural_networks.ipynb
  11. 4 4
      p1ch6/2_activation_functions.ipynb
  12. 1 1
      p1ch6/3_nn_module_subclassing.ipynb
  13. 26 23
      p2ch09/dsets.py
  14. 1 1
      p2ch09/vis.py
  15. 27 24
      p2ch09_explore_data.ipynb
  16. 26 26
      p2ch10/dsets.py
  17. 32 20
      p2ch10/model.py
  18. 246 92
      p2ch10/training.py
  19. 59 0
      p2ch11/1_final_metric_f1_score.ipynb
  20. 0 0
      p2ch11/__init__.py
  21. 593 0
      p2ch11/diagnose.py
  22. 316 0
      p2ch11/dsets.py
  23. 63 0
      p2ch11/model.py
  24. 328 0
      p2ch11/model_segmentation.py
  25. 63 0
      p2ch11/prepcache.py
  26. 420 0
      p2ch11/training.py
  27. 87 0
      p2ch11/vis.py
  28. 0 0
      p2ch12/__init__.py
  29. 378 0
      p2ch12/diagnose.py
  30. 569 0
      p2ch12/dsets.py
  31. 41 0
      p2ch12/model.py
  32. 109 0
      p2ch12/model_cls.py
  33. 46 0
      p2ch12/model_seg.py
  34. 72 0
      p2ch12/prepcache.py
  35. 92 0
      p2ch12/screencts.py
  36. 454 0
      p2ch12/train_cls.py
  37. 538 0
      p2ch12/train_seg.py
  38. 554 0
      p2ch12/training.py
  39. 86 0
      p2ch12/vis.py
  40. 115 0
      p2ch12_explore_data.ipynb
  41. 122 0
      p2ch12_explore_diagnose.ipynb
  42. 1 1
      util/disk.py
  43. 4 4
      util/unet.py
  44. 51 56
      util/util.py

BIN
data/p1ch3/ourpoints.hdf5


BIN
data/p1ch3/ourpoints.t


+ 1 - 0
data/p1ch6/cifar-10-batches-py/readme.html

@@ -0,0 +1 @@
+<meta HTTP-EQUIV="REFRESH" content="0; url=http://www.cs.toronto.edu/~kriz/cifar.html">

BIN
data/part2/models/alternate_cls_2019-06-16_14.18.58_ch12-other-model.best.state


BIN
data/part2/models/cls_2019-06-23_14.57.45_redo.best.state


BIN
data/part2/models/seg_2019-06-22_22.55.11_ch12-no-aug-bn-all-metrics-more-samples.750000.state


+ 59 - 59
p1ch3/1_tensors.ipynb

@@ -140,12 +140,12 @@
    "outputs": [],
    "source": [
     "points = torch.zeros(6) # <1>\n",
-    "points[0] = 1.0 # <2>\n",
-    "points[1] = 4.0\n",
-    "points[2] = 2.0\n",
-    "points[3] = 1.0\n",
-    "points[4] = 3.0\n",
-    "points[5] = 5.0"
+    "points[0] = 4.0 # <2>\n",
+    "points[1] = 1.0\n",
+    "points[2] = 5.0\n",
+    "points[3] = 3.0\n",
+    "points[4] = 2.0\n",
+    "points[5] = 1.0"
    ]
   },
   {
@@ -156,7 +156,7 @@
     {
      "data": {
       "text/plain": [
-       "tensor([1., 4., 2., 1., 3., 5.])"
+       "tensor([4., 1., 5., 3., 2., 1.])"
       ]
      },
      "execution_count": 9,
@@ -165,7 +165,7 @@
     }
    ],
    "source": [
-    "points = torch.tensor([1.0, 4.0, 2.0, 1.0, 3.0, 5.0])\n",
+    "points = torch.tensor([4.0, 1.0, 5.0, 3.0, 2.0, 1.0])\n",
     "points"
    ]
   },
@@ -177,7 +177,7 @@
     {
      "data": {
       "text/plain": [
-       "(1.0, 4.0)"
+       "(4.0, 1.0)"
       ]
      },
      "execution_count": 10,
@@ -197,9 +197,9 @@
     {
      "data": {
       "text/plain": [
-       "tensor([[1., 4.],\n",
-       "        [2., 1.],\n",
-       "        [3., 5.]])"
+       "tensor([[4., 1.],\n",
+       "        [5., 3.],\n",
+       "        [2., 1.]])"
       ]
      },
      "execution_count": 11,
@@ -208,7 +208,7 @@
     }
    ],
    "source": [
-    "points = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 5.0]])\n",
+    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "points"
    ]
   },
@@ -263,9 +263,9 @@
     {
      "data": {
       "text/plain": [
-       "tensor([[1., 4.],\n",
-       "        [2., 1.],\n",
-       "        [3., 5.]])"
+       "tensor([[4., 1.],\n",
+       "        [5., 3.],\n",
+       "        [2., 1.]])"
       ]
      },
      "execution_count": 14,
@@ -274,7 +274,7 @@
     }
    ],
    "source": [
-    "points = torch.FloatTensor([[1.0, 4.0], [2.0, 1.0], [3.0, 5.0]])\n",
+    "points = torch.FloatTensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "points"
    ]
   },
@@ -286,7 +286,7 @@
     {
      "data": {
       "text/plain": [
-       "tensor(4.)"
+       "tensor(1.)"
       ]
      },
      "execution_count": 15,
@@ -306,7 +306,7 @@
     {
      "data": {
       "text/plain": [
-       "tensor([1., 4.])"
+       "tensor([4., 1.])"
       ]
      },
      "execution_count": 16,
@@ -326,12 +326,12 @@
     {
      "data": {
       "text/plain": [
-       " 1.0\n",
        " 4.0\n",
-       " 2.0\n",
        " 1.0\n",
-       " 3.0\n",
        " 5.0\n",
+       " 3.0\n",
+       " 2.0\n",
+       " 1.0\n",
        "[torch.FloatStorage of size 6]"
       ]
      },
@@ -341,7 +341,7 @@
     }
    ],
    "source": [
-    "points = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 5.0]])\n",
+    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "points.storage()"
    ]
   },
@@ -353,7 +353,7 @@
     {
      "data": {
       "text/plain": [
-       "1.0"
+       "4.0"
       ]
      },
      "execution_count": 18,
@@ -374,7 +374,7 @@
     {
      "data": {
       "text/plain": [
-       "4.0"
+       "1.0"
       ]
      },
      "execution_count": 19,
@@ -394,9 +394,9 @@
     {
      "data": {
       "text/plain": [
-       "tensor([[2., 4.],\n",
-       "        [2., 1.],\n",
-       "        [3., 5.]])"
+       "tensor([[2., 1.],\n",
+       "        [5., 3.],\n",
+       "        [2., 1.]])"
       ]
      },
      "execution_count": 20,
@@ -405,7 +405,7 @@
     }
    ],
    "source": [
-    "points = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 5.0]])\n",
+    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "points_storage = points.storage()\n",
     "points_storage[0] = 2.0\n",
     "points"
@@ -428,7 +428,7 @@
     }
    ],
    "source": [
-    "points = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 5.0]])\n",
+    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "second_point = points[1]\n",
     "second_point.storage_offset()"
    ]
@@ -510,7 +510,7 @@
     }
    ],
    "source": [
-    "points = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 5.0]])\n",
+    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "second_point = points[1]\n",
     "second_point.size()"
    ]
@@ -563,9 +563,9 @@
     {
      "data": {
       "text/plain": [
-       "tensor([[ 1.,  4.],\n",
-       "        [10.,  1.],\n",
-       "        [ 3.,  5.]])"
+       "tensor([[ 4.,  1.],\n",
+       "        [10.,  3.],\n",
+       "        [ 2.,  1.]])"
       ]
      },
      "execution_count": 28,
@@ -574,7 +574,7 @@
     }
    ],
    "source": [
-    "points = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 5.0]])\n",
+    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "second_point = points[1]\n",
     "second_point[0] = 10.0\n",
     "points"
@@ -588,9 +588,9 @@
     {
      "data": {
       "text/plain": [
-       "tensor([[1., 4.],\n",
-       "        [2., 1.],\n",
-       "        [3., 5.]])"
+       "tensor([[4., 1.],\n",
+       "        [5., 3.],\n",
+       "        [2., 1.]])"
       ]
      },
      "execution_count": 29,
@@ -599,7 +599,7 @@
     }
    ],
    "source": [
-    "points = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 5.0]])\n",
+    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "second_point = points[1].clone()\n",
     "second_point[0] = 10.0\n",
     "points"
@@ -613,9 +613,9 @@
     {
      "data": {
       "text/plain": [
-       "tensor([[1., 4.],\n",
-       "        [2., 1.],\n",
-       "        [3., 5.]])"
+       "tensor([[4., 1.],\n",
+       "        [5., 3.],\n",
+       "        [2., 1.]])"
       ]
      },
      "execution_count": 30,
@@ -624,7 +624,7 @@
     }
    ],
    "source": [
-    "points = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 5.0]])\n",
+    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "points"
    ]
   },
@@ -636,8 +636,8 @@
     {
      "data": {
       "text/plain": [
-       "tensor([[1., 2., 3.],\n",
-       "        [4., 1., 5.]])"
+       "tensor([[4., 5., 2.],\n",
+       "        [1., 3., 1.]])"
       ]
      },
      "execution_count": 31,
@@ -840,8 +840,8 @@
     {
      "data": {
       "text/plain": [
-       "tensor([[1., 2., 3.],\n",
-       "        [4., 1., 5.]])"
+       "tensor([[4., 5., 2.],\n",
+       "        [1., 3., 1.]])"
       ]
      },
      "execution_count": 41,
@@ -850,7 +850,7 @@
     }
    ],
    "source": [
-    "points = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 5.0]])\n",
+    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "points_t = points.t()\n",
     "points_t"
    ]
@@ -863,12 +863,12 @@
     {
      "data": {
       "text/plain": [
-       " 1.0\n",
        " 4.0\n",
-       " 2.0\n",
        " 1.0\n",
-       " 3.0\n",
        " 5.0\n",
+       " 3.0\n",
+       " 2.0\n",
+       " 1.0\n",
        "[torch.FloatStorage of size 6]"
       ]
      },
@@ -909,8 +909,8 @@
     {
      "data": {
       "text/plain": [
-       "tensor([[1., 2., 3.],\n",
-       "        [4., 1., 5.]])"
+       "tensor([[4., 5., 2.],\n",
+       "        [1., 3., 1.]])"
       ]
      },
      "execution_count": 44,
@@ -951,12 +951,12 @@
     {
      "data": {
       "text/plain": [
-       " 1.0\n",
+       " 4.0\n",
+       " 5.0\n",
        " 2.0\n",
+       " 1.0\n",
        " 3.0\n",
-       " 4.0\n",
        " 1.0\n",
-       " 5.0\n",
        "[torch.FloatStorage of size 6]"
       ]
      },
@@ -1036,7 +1036,7 @@
    "outputs": [],
    "source": [
     "# reset points back to original value\n",
-    "points = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 4.0]])"
+    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])"
    ]
   },
   {
@@ -1073,7 +1073,7 @@
     {
      "data": {
       "text/plain": [
-       "tensor([2., 3.])"
+       "tensor([5., 2.])"
       ]
      },
      "execution_count": 54,
@@ -1198,7 +1198,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "points_gpu = torch.tensor([[1.0, 4.0], [2.0, 1.0], [3.0, 4.0]], device='cuda')"
+    "points_gpu = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]], device='cuda')"
    ]
   },
   {

+ 13 - 13
p1ch4/1_tabular_wine.ipynb

@@ -8,7 +8,7 @@
    "source": [
     "import numpy as np\n",
     "import torch\n",
-    "torch.set_printoptions(edgeitems=2)"
+    "torch.set_printoptions(edgeitems=2, precision=2)"
    ]
   },
   {
@@ -104,11 +104,11 @@
     {
      "data": {
       "text/plain": [
-       "(tensor([[ 7.0000,  0.2700,  ...,  0.4500,  8.8000],\n",
-       "         [ 6.3000,  0.3000,  ...,  0.4900,  9.5000],\n",
+       "(tensor([[ 7.00,  0.27,  ...,  0.45,  8.80],\n",
+       "         [ 6.30,  0.30,  ...,  0.49,  9.50],\n",
        "         ...,\n",
-       "         [ 5.5000,  0.2900,  ...,  0.3800, 12.8000],\n",
-       "         [ 6.0000,  0.2100,  ...,  0.3200, 11.8000]]), torch.Size([4898, 11]))"
+       "         [ 5.50,  0.29,  ...,  0.38, 12.80],\n",
+       "         [ 6.00,  0.21,  ...,  0.32, 11.80]]), torch.Size([4898, 11]))"
       ]
      },
      "execution_count": 5,
@@ -222,8 +222,8 @@
     {
      "data": {
       "text/plain": [
-       "tensor([6.8548e+00, 2.7824e-01, 3.3419e-01, 6.3914e+00, 4.5772e-02, 3.5308e+01,\n",
-       "        1.3836e+02, 9.9403e-01, 3.1883e+00, 4.8985e-01, 1.0514e+01])"
+       "tensor([6.85e+00, 2.78e-01, 3.34e-01, 6.39e+00, 4.58e-02, 3.53e+01, 1.38e+02,\n",
+       "        9.94e-01, 3.19e+00, 4.90e-01, 1.05e+01])"
       ]
      },
      "execution_count": 10,
@@ -244,8 +244,8 @@
     {
      "data": {
       "text/plain": [
-       "tensor([7.1211e-01, 1.0160e-02, 1.4646e-02, 2.5726e+01, 4.7733e-04, 2.8924e+02,\n",
-       "        1.8061e+03, 8.9455e-06, 2.2801e-02, 1.3025e-02, 1.5144e+00])"
+       "tensor([7.12e-01, 1.02e-02, 1.46e-02, 2.57e+01, 4.77e-04, 2.89e+02, 1.81e+03,\n",
+       "        8.95e-06, 2.28e-02, 1.30e-02, 1.51e+00])"
       ]
      },
      "execution_count": 11,
@@ -266,11 +266,11 @@
     {
      "data": {
       "text/plain": [
-       "tensor([[ 1.7209e-01, -8.1764e-02,  ..., -3.4914e-01, -1.3930e+00],\n",
-       "        [-6.5743e-01,  2.1587e-01,  ...,  1.3467e-03, -8.2418e-01],\n",
+       "tensor([[ 1.72e-01, -8.18e-02,  ..., -3.49e-01, -1.39e+00],\n",
+       "        [-6.57e-01,  2.16e-01,  ...,  1.35e-03, -8.24e-01],\n",
        "        ...,\n",
-       "        [-1.6054e+00,  1.1666e-01,  ..., -9.6250e-01,  1.8574e+00],\n",
-       "        [-1.0129e+00, -6.7703e-01,  ..., -1.4882e+00,  1.0448e+00]])"
+       "        [-1.61e+00,  1.17e-01,  ..., -9.63e-01,  1.86e+00],\n",
+       "        [-1.01e+00, -6.77e-01,  ..., -1.49e+00,  1.04e+00]])"
       ]
      },
      "execution_count": 12,

File diff suppressed because it is too large
+ 29 - 29
p1ch5/1_parameter_estimation.ipynb


File diff suppressed because it is too large
+ 51 - 51
p1ch6/1_neural_networks.ipynb


File diff suppressed because it is too large
+ 4 - 4
p1ch6/2_activation_functions.ipynb


+ 1 - 1
p1ch6/3_nn_module_subclassing.ipynb

@@ -39,7 +39,7 @@
     "seq_model = nn.Sequential(\n",
     "            nn.Linear(1, 11), # <1>\n",
     "            nn.Tanh(),\n",
-    "            nn.Linear(11, 1)) # <2>\n",
+    "            nn.Linear(11, 1)) # <1>\n",
     "seq_model"
    ]
   },

+ 26 - 23
p2ch09/dsets.py

@@ -5,6 +5,8 @@ import glob
 import os
 import random
 
+from collections import namedtuple
+
 import SimpleITK as sitk
 
 import numpy as np
@@ -18,10 +20,12 @@ from util.logconf import logging
 
 log = logging.getLogger(__name__)
 # log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
+# log.setLevel(logging.INFO)
 log.setLevel(logging.DEBUG)
 
-raw_cache = getCache('part2ch08_raw')
+raw_cache = getCache('part2ch09_raw')
+
+NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
 
 @functools.lru_cache(1)
 def getNoduleInfoList(requireDataOnDisk_bool=True):
@@ -61,7 +65,7 @@ def getNoduleInfoList(requireDataOnDisk_bool=True):
                     candidateDiameter_mm = annotationDiameter_mm
                     break
 
-            noduleInfo_list.append((isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+            noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
 
     noduleInfo_list.sort(reverse=True)
     return noduleInfo_list
@@ -75,15 +79,11 @@ class Ct(object):
 
         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
         # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
-        # This converts HU to g/cc.
-        ct_ary += 1000
-        ct_ary /= 1000
-
         # This gets rid of negative density stuff used to indicate out-of-FOV
-        ct_ary[ct_ary < 0] = 0
+        ct_ary[ct_ary < -1000] = -1000
 
         # This nukes any weird hotspots and clamps bone down
-        ct_ary[ct_ary > 2] = 2
+        ct_ary[ct_ary > 1000] = 1000
 
         self.series_uid = series_uid
         self.ary = ct_ary
@@ -116,7 +116,7 @@ class Ct(object):
 
             slice_list.append(slice(start_ndx, end_ndx))
 
-        ct_chunk = self.ary[slice_list]
+        ct_chunk = self.ary[tuple(slice_list)]
 
         return ct_chunk, center_irc
 
@@ -142,7 +142,6 @@ class LunaDataset(Dataset):
         if series_uid:
             self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]
 
-        # __init__ continued...
         if test_stride > 1:
             if isTestSet_bool:
                 self.noduleInfo_list = self.noduleInfo_list[::test_stride]
@@ -159,18 +158,22 @@ class LunaDataset(Dataset):
         return len(self.noduleInfo_list)
 
     def __getitem__(self, ndx):
-        sample_ndx = ndx
-
-        isMalignant_bool, diameter_mm, series_uid, center_xyz = self.noduleInfo_list[sample_ndx]
-
-        nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
-
-        nodule_tensor = torch.from_numpy(nodule_ary)
+        nodule_tup = self.noduleInfo_list[ndx]
+        width_irc = (24, 48, 48)
+
+        nodule_ary, center_irc = getCtRawNodule(
+            nodule_tup.series_uid,
+            nodule_tup.center_xyz,
+            width_irc,
+        )
+        nodule_tensor = torch.from_numpy(nodule_ary).to(torch.float32)
         nodule_tensor = nodule_tensor.unsqueeze(0)
 
-        malignant_tensor = torch.tensor([isMalignant_bool], dtype=torch.float32)
-
-        return nodule_tensor, malignant_tensor, series_uid, center_irc
-
-
+        cls_tensor = torch.tensor([
+                not nodule_tup.isMalignant_bool,
+                nodule_tup.isMalignant_bool
+            ],
+            dtype=torch.long,
+        )
 
+        return nodule_tensor, cls_tensor, nodule_tup.series_uid, center_irc
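The Ct preprocessing above no longer rescales HU into g/cc; it simply clamps the raw Hounsfield units to the [-1000, 1000] range. A minimal sketch of that clamping (not from this commit; the sample values are made up for illustration):

import numpy as np

# raw HU: out-of-FOV padding, air, water, bone, metal artifact
ct_ary = np.array([-2048.0, -1000.0, 0.0, 400.0, 3000.0], dtype=np.float32)

ct_ary[ct_ary < -1000] = -1000   # out-of-FOV padding is treated as air
ct_ary[ct_ary > 1000] = 1000     # clamp bone/metal hotspots
# equivalently: ct_ary.clip(-1000, 1000, out=ct_ary)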

+ 1 - 1
p2ch09/vis.py

@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
 
 from p2ch09.dsets import Ct, LunaDataset
 
-clim=(0.0, 1.3)
+clim=(-1000.0, 300)
 
 def findMalignantSamples(start_ndx=0, limit=100):
     ds = LunaDataset()

File diff suppressed because it is too large
+ 27 - 24
p2ch09_explore_data.ipynb


+ 26 - 26
p2ch10/dsets.py

@@ -5,6 +5,8 @@ import glob
 import os
 import random
 
+from collections import namedtuple
+
 import SimpleITK as sitk
 
 import numpy as np
@@ -18,11 +20,13 @@ from util.logconf import logging
 
 log = logging.getLogger(__name__)
 # log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
+# log.setLevel(logging.INFO)
 log.setLevel(logging.DEBUG)
 
 raw_cache = getCache('part2ch09_raw')
 
+NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
+
 @functools.lru_cache(1)
 def getNoduleInfoList(requireDataOnDisk_bool=True):
     # We construct a set with all series_uids that are present on disk.
@@ -61,7 +65,7 @@ def getNoduleInfoList(requireDataOnDisk_bool=True):
                     candidateDiameter_mm = annotationDiameter_mm
                     break
 
-            noduleInfo_list.append((isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+            noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
 
     noduleInfo_list.sort(reverse=True)
     return noduleInfo_list
@@ -75,15 +79,11 @@ class Ct(object):
 
         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
         # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
-        # This converts HU to g/cc.
-        ct_ary += 1000
-        ct_ary /= 1000
-
         # This gets rid of negative density stuff used to indicate out-of-FOV
-        ct_ary[ct_ary < 0] = 0
+        ct_ary[ct_ary < -1000] = -1000
 
         # This nukes any weird hotspots and clamps bone down
-        ct_ary[ct_ary > 2] = 2
+        ct_ary[ct_ary > 1000] = 1000
 
         self.series_uid = series_uid
         self.ary = ct_ary
@@ -116,7 +116,7 @@ class Ct(object):
 
             slice_list.append(slice(start_ndx, end_ndx))
 
-        ct_chunk = self.ary[slice_list]
+        ct_chunk = self.ary[tuple(slice_list)]
 
         return ct_chunk, center_irc
 
@@ -131,6 +131,7 @@ def getCtRawNodule(series_uid, center_xyz, width_irc):
     ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
     return ct_chunk, center_irc
 
+
 class LunaDataset(Dataset):
     def __init__(self,
                  test_stride=0,
@@ -143,7 +144,6 @@ class LunaDataset(Dataset):
         if series_uid:
             self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]
 
-        # __init__ continued...
         if test_stride > 1:
             if isTestSet_bool:
                 self.noduleInfo_list = self.noduleInfo_list[::test_stride]
@@ -165,26 +165,26 @@ class LunaDataset(Dataset):
             "testing" if isTestSet_bool else "training",
         ))
 
-
     def __len__(self):
-        # if self.ratio_int:
-        #     return min(len(self.benignIndex_list), len(self.malignantIndex_list)) * 4 * 90
-        # else:
         return len(self.noduleInfo_list)
 
     def __getitem__(self, ndx):
-        sample_ndx = ndx
-
-        isMalignant_bool, _diameter_mm, series_uid, center_xyz = self.noduleInfo_list[sample_ndx]
-
-        nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
-
-        nodule_tensor = torch.from_numpy(nodule_ary)
+        nodule_tup = self.noduleInfo_list[ndx]
+        width_irc = (24, 48, 48)
+
+        nodule_ary, center_irc = getCtRawNodule(
+            nodule_tup.series_uid,
+            nodule_tup.center_xyz,
+            width_irc,
+        )
+        nodule_tensor = torch.from_numpy(nodule_ary).to(torch.float32)
         nodule_tensor = nodule_tensor.unsqueeze(0)
 
-        malignant_tensor = torch.tensor([isMalignant_bool], dtype=torch.float32)
-
-        return nodule_tensor, malignant_tensor, series_uid, center_irc
-
-
+        cls_tensor = torch.tensor([
+                not nodule_tup.isMalignant_bool,
+                nodule_tup.isMalignant_bool
+            ],
+            dtype=torch.long,
+        )
 
+        return nodule_tensor, cls_tensor, nodule_tup.series_uid, center_irc
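getRawNodule now indexes with a tuple of slices rather than a list; NumPy reserves lists for fancy indexing, so a tuple is needed for plain basic slicing. A small standalone illustration with a dummy array (not from this commit):

import numpy as np

ary = np.arange(4 * 6 * 6).reshape(4, 6, 6)
slice_list = [slice(1, 3), slice(0, 4), slice(2, 5)]

ct_chunk = ary[tuple(slice_list)]  # basic slicing, shape (2, 4, 3)
# ary[slice_list] is deprecated, and raises an error in newer NumPy releases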

+ 32 - 20
p2ch10/model.py

@@ -1,3 +1,4 @@
+import math
 
 import torch
 from torch import nn as nn
@@ -9,44 +10,55 @@ log = logging.getLogger(__name__)
 # log.setLevel(logging.INFO)
 log.setLevel(logging.DEBUG)
 
+
 class LunaModel(nn.Module):
     def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
         super().__init__()
 
+        self.input_batchnorm = nn.BatchNorm2d(1)
+
         layer_list = []
         for layer_ndx in range(layer_count):
             layer_list += [
-                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False),
-                nn.BatchNorm3d(conv_channels), # eli: will assume that p1ch6 doesn't use this
-                nn.LeakyReLU(inplace=True), # eli: will assume plan ReLU
-                nn.Dropout3d(p=0.2),  # eli: will assume that p1ch6 doesn't use this
-
-                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=False),
-                nn.BatchNorm3d(conv_channels),
-                nn.LeakyReLU(inplace=True),
-                nn.Dropout3d(p=0.2),
+                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
+                nn.ReLU(inplace=True),
+                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
+                nn.ReLU(inplace=True),
 
                 nn.MaxPool3d(2, 2),
- # tag::model_init[]
            ]
 
             in_channels = conv_channels
             conv_channels *= 2
 
         self.convAndPool_seq = nn.Sequential(*layer_list)
-        self.fullyConnected_layer = nn.Linear(512, 1)
-        self.final = nn.Hardtanh(min_val=0.0, max_val=1.0)
+        self.fullyConnected_layer = nn.Linear(576, 2)
+        self.final = nn.Softmax(dim=1)
+
+        self._init_weights()
+
+    # see also https://github.com/pytorch/pytorch/issues/18182
+    def _init_weights(self):
+        for m in self.modules():
+            if type(m) in {
+                nn.Linear,
+                nn.Conv3d,
+                nn.Conv2d,
+                nn.ConvTranspose2d,
+                nn.ConvTranspose3d,
+            }:
+                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
 
 
     def forward(self, input_batch):
-        conv_output = self.convAndPool_seq(input_batch)
+        bn_output = self.input_batchnorm(input_batch)
+        conv_output = self.convAndPool_seq(bn_output)
         conv_flat = conv_output.view(conv_output.size(0), -1)
+        classifier_output = self.fullyConnected_layer(conv_flat)
 
-        try:
-            classifier_output = self.fullyConnected_layer(conv_flat)
-        except:
-            log.debug(conv_flat.size())
-            raise
+        return classifier_output, self.final(classifier_output)
 
-        classifier_output = self.final(classifier_output)
-        return classifier_output
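The classifier head changes from Linear(512, 1) to Linear(576, 2) because the dataset now yields (24, 48, 48) crops and the model emits two-class logits: four conv/pool blocks shrink the depth 24 -> 12 -> 6 -> 3 -> 1 and the height/width 48 -> 24 -> 12 -> 6 -> 3 while the channels grow to 64, so the flattened features number 64 * 1 * 3 * 3 = 576. A quick sketch to verify that size, assuming the same block structure as in the diff above:

import torch
from torch import nn

def block(c_in, c_out):
    return nn.Sequential(
        nn.Conv3d(c_in, c_out, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv3d(c_out, c_out, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool3d(2, 2),
    )

backbone = nn.Sequential(block(1, 8), block(8, 16), block(16, 32), block(32, 64))
out = backbone(torch.zeros(1, 1, 24, 48, 48))  # one (24, 48, 48) nodule crop
print(out.shape, out.flatten(1).shape)         # [1, 64, 1, 3, 3] -> [1, 576]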

+ 246 - 92
p2ch10/training.py

@@ -4,6 +4,7 @@ import os
 import sys
 
 import numpy as np
+
 from tensorboardX import SummaryWriter
 
 import torch
@@ -25,6 +26,7 @@ log.setLevel(logging.INFO)
 METRICS_LABEL_NDX=0
 METRICS_PRED_NDX=1
 METRICS_LOSS_NDX=2
+METRICS_SIZE = 3
 
 class LunaTrainingApp(object):
     def __init__(self, sys_argv=None):
@@ -48,43 +50,92 @@ class LunaTrainingApp(object):
             type=int,
         )
 
+        parser.add_argument('--tb-prefix',
+            default='p2ch10',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='none',
+        )
+
         self.cli_args = parser.parse_args(sys_argv)
-        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
 
-    def main(self):
-        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+        self.trn_writer = None
+        self.tst_writer = None
+        self.totalTrainingSamples_count = 0
 
         self.use_cuda = torch.cuda.is_available()
         self.device = torch.device("cuda" if self.use_cuda else "cpu")
 
-        self.model = LunaModel()
+        self.model = self.initModel()
+        self.optimizer = self.initOptimizer()
+
+    def initModel(self):
+        model = LunaModel()
         if self.use_cuda:
             if torch.cuda.device_count() > 1:
-                self.model = nn.DataParallel(self.model)
-
-            self.model = self.model.to(self.device)
-        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
+                model = nn.DataParallel(model)
+            model = model.to(self.device)
+        return model
+
+    def initOptimizer(self):
+        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
+        # return Adam(self.model.parameters())
+
+    def initTrainDl(self):
+        train_ds = LunaDataset(
+            test_stride=10,
+            isTestSet_bool=False,
+        )
 
         train_dl = DataLoader(
-            LunaDataset(
-                test_stride=10,
-                isTestSet_bool=False,
-            ),
+            train_ds,
             batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
             num_workers=self.cli_args.num_workers,
             pin_memory=self.use_cuda,
         )
 
+        return train_dl
+
+    def initTestDl(self):
+        test_ds = LunaDataset(
+            test_stride=10,
+            isTestSet_bool=True,
+        )
+
         test_dl = DataLoader(
-            LunaDataset(
-                test_stride=10,
-                isTestSet_bool=True,
-            ),
+            test_ds,
             batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
             num_workers=self.cli_args.num_workers,
             pin_memory=self.use_cuda,
         )
 
+        return test_dl
+
+    def initTensorboardWriters(self):
+        if self.trn_writer is None:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_cls_' + self.cli_args.comment)
+            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_cls_' + self.cli_args.comment)
+# end::tb_writer[]
+
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+        test_dl = self.initTestDl()
+
+        self.initTensorboardWriters()
+        # self.logModelMetrics(self.model)
+
+        # best_score = 0.0
+
         for epoch_ndx in range(1, self.cli_args.epochs + 1):
 
             log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
@@ -96,120 +147,223 @@ class LunaTrainingApp(object):
                 (torch.cuda.device_count() if self.use_cuda else 1),
             ))
 
-            # Training loop, very similar to below
-            self.model.train()
-            trainingMetrics_tensor = torch.zeros(3, len(train_dl.dataset), 1)
+            trnMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_tensor)
+
+            tstMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
+            self.logMetrics(epoch_ndx, 'tst', tstMetrics_tensor)
+
+        if hasattr(self, 'trn_writer'):
+            self.trn_writer.close()
+            self.tst_writer.close()
+
+
+    def doTraining(self, epoch_ndx, train_dl):
+        self.model.train()
+        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            loss_var = self.computeBatchLoss(
+                batch_ndx,
+                batch_tup,
+                train_dl.batch_size,
+                trainingMetrics_devtensor
+            )
+
+            loss_var.backward()
+            self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
+
+        return trainingMetrics_devtensor.to('cpu')
+
+
+    def doTesting(self, epoch_ndx, test_dl):
+        with torch.no_grad():
+            self.model.eval()
+            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
             batch_iter = enumerateWithEstimate(
-                train_dl,
-                "E{} Training".format(epoch_ndx),
-                start_ndx=train_dl.num_workers,
+                test_dl,
+                "E{} Testing ".format(epoch_ndx),
+                start_ndx=test_dl.num_workers,
             )
             for batch_ndx, batch_tup in batch_iter:
-                self.optimizer.zero_grad()
-                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
-                loss_var.backward()
-                self.optimizer.step()
-                del loss_var
-
-            # Testing loop, very similar to above, but simplified
-            with torch.no_grad():
-                self.model.eval()
-                testingMetrics_tensor = torch.zeros(3, len(test_dl.dataset), 1)
-                batch_iter = enumerateWithEstimate(
-                    test_dl,
-                    "E{} Testing ".format(epoch_ndx),
-                    start_ndx=test_dl.num_workers,
-                )
-                for batch_ndx, batch_tup in batch_iter:
-                    self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
-
-            self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
-
-
-    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
-        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
+
+        return testingMetrics_devtensor.to('cpu')
 
-        input_devtensor = input_tensor.to(self.device)
-        label_devtensor = label_tensor.to(self.device)
 
-        prediction_devtensor = self.model(input_devtensor)
-        loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)
 
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
+        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+
+        input_devtensor = input_tensor.to(self.device, non_blocking=True)
+        label_devtensor = label_tensor.to(self.device, non_blocking=True)
 
+        logits_devtensor, probability_devtensor = self.model(input_devtensor)
+
+        loss_func = nn.CrossEntropyLoss(reduction='none')
+        loss_devtensor = loss_func(logits_devtensor, label_devtensor[:,1])
         start_ndx = batch_ndx * batch_size
         end_ndx = start_ndx + label_tensor.size(0)
-        metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
-        metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.to('cpu')
-        metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')
 
-        # TODO: replace with torch.autograd.detect_anomaly
-        # assert np.isfinite(metrics_tensor).all()
+        metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_devtensor[:,1]
+        metrics_devtensor[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_devtensor[:,1]
+        metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor
 
         return loss_devtensor.mean()
 
 
-    def logMetrics(self,
-                   epoch_ndx,
-                   trainingMetrics_tensor,
-                   testingMetrics_tensor,
-                   classificationThreshold_float=0.5,
-                   ):
+    def logMetrics(
+            self,
+            epoch_ndx,
+            mode_str,
+            metrics_tensor,
+    ):
         log.info("E{} {}".format(
             epoch_ndx,
             type(self).__name__,
         ))
 
-        for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
-            metrics_ary = metrics_tensor.detach().numpy()[:,:,0]
-            assert np.isfinite(metrics_ary).all()
+        metrics_ary = metrics_tensor.cpu().detach().numpy()
+#         assert np.isfinite(metrics_ary).all()
 
-            benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= classificationThreshold_float
-            benPred_mask = metrics_ary[METRICS_PRED_NDX] <= classificationThreshold_float
+        benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= 0.5
+        benPred_mask = metrics_ary[METRICS_PRED_NDX] <= 0.5
 
-            malLabel_mask = ~benLabel_mask
-            malPred_mask = ~benPred_mask
+        malLabel_mask = ~benLabel_mask
+        malPred_mask = ~benPred_mask
 
-            benLabel_count = benLabel_mask.sum()
-            malLabel_count = malLabel_mask.sum()
+        benLabel_count = benLabel_mask.sum()
+        malLabel_count = malLabel_mask.sum()
 
-            benCorrect_count = (benLabel_mask & benPred_mask).sum()
-            malCorrect_count = (malLabel_mask & malPred_mask).sum()
+        benCorrect_count = (benLabel_mask & benPred_mask).sum()
+        malCorrect_count = (malLabel_mask & malPred_mask).sum()
 
-            metrics_dict = {}
+        # trueNeg_count = benCorrect_count = (benLabel_mask & benPred_mask).sum()
+        # truePos_count = malCorrect_count = (malLabel_mask & malPred_mask).sum()
+        #
+        # falsePos_count = benLabel_count - benCorrect_count
+        # falseNeg_count = malLabel_count - malCorrect_count
 
-            metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
-            metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
-            metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
+        # log.info(['min loss', metrics_ary[METRICS_LOSS_NDX, benLabel_mask].min(), metrics_ary[METRICS_LOSS_NDX, malLabel_mask].min()])
+        # log.info(['max loss', metrics_ary[METRICS_LOSS_NDX, benLabel_mask].max(), metrics_ary[METRICS_LOSS_NDX, malLabel_mask].max()])
 
-            metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
-            metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
-            metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
 
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
 
+        metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
+        metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
+        metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
 
-
-            log.info(("E{} {:8} "
-                     + "{loss/all:.4f} loss, "
-                     + "{correct/all:-5.1f}% correct"
-                      ).format(
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/all:.4f} loss, "
+                 + "{correct/all:-5.1f}% correct, "
+            ).format(
                 epoch_ndx,
                 mode_str,
                 **metrics_dict,
-            ))
-            log.info(("E{} {:8} "
-                     + "{loss/ben:.4f} loss, "
-                     + "{correct/ben:-5.1f}% correct").format(
+            )
+        )
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/ben:.4f} loss, "
+                 + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
+            ).format(
                 epoch_ndx,
                 mode_str + '_ben',
+                benCorrect_count=benCorrect_count,
+                benLabel_count=benLabel_count,
                 **metrics_dict,
-            ))
-            log.info(("E{} {:8} "
-                     + "{loss/mal:.4f} loss, "
-                     + "{correct/mal:-5.1f}% correct").format(
+            )
+        )
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/mal:.4f} loss, "
+                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+            ).format(
                 epoch_ndx,
                 mode_str + '_mal',
+                malCorrect_count=malCorrect_count,
+                malLabel_count=malLabel_count,
                 **metrics_dict,
-            ))
+            )
+        )
+
+        writer = getattr(self, mode_str + '_writer')
+
+        for key, value in metrics_dict.items():
+            writer.add_scalar(key, value, self.totalTrainingSamples_count)
+
+        writer.add_pr_curve(
+            'pr',
+            metrics_ary[METRICS_LABEL_NDX],
+            metrics_ary[METRICS_PRED_NDX],
+            self.totalTrainingSamples_count,
+        )
+
+        bins = [x/50.0 for x in range(51)]
+
+        benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
+        malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
+
+        if benHist_mask.any():
+            writer.add_histogram(
+                'is_ben',
+                metrics_ary[METRICS_PRED_NDX, benHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+        if malHist_mask.any():
+            writer.add_histogram(
+                'is_mal',
+                metrics_ary[METRICS_PRED_NDX, malHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+
+        # score = 1 \
+        #     + metrics_dict['pr/f1_score'] \
+        #     - metrics_dict['loss/mal'] * 0.01 \
+        #     - metrics_dict['loss/all'] * 0.0001
+        #
+        # return score
+
+    # def logModelMetrics(self, model):
+    #     writer = getattr(self, 'trn_writer')
+    #
+    #     model = getattr(model, 'module', model)
+    #
+    #     for name, param in model.named_parameters():
+    #         if param.requires_grad:
+    #             min_data = float(param.data.min())
+    #             max_data = float(param.data.max())
+    #             max_extent = max(abs(min_data), abs(max_data))
+    #
+    #             # bins = [x/50*max_extent for x in range(-50, 51)]
+    #
+    #             try:
+    #                 writer.add_histogram(
+    #                     name.rsplit('.', 1)[-1] + '/' + name,
+    #                     param.data.cpu().numpy(),
+    #                     # metrics_ary[METRICS_PRED_NDX, benHist_mask],
+    #                     self.totalTrainingSamples_count,
+    #                     # bins=bins,
+    #                 )
+    #             except Exception as e:
+    #                 log.error([min_data, max_data])
+    #                 raise
 
 
 if __name__ == '__main__':
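computeBatchLoss now keeps a per-sample record in a (METRICS_SIZE, n_samples) tensor and uses CrossEntropyLoss with reduction='none', so that logMetrics can split the statistics by class afterwards. A stripped-down sketch of that bookkeeping (standalone, with made-up values, not from this commit):

import torch
from torch import nn

METRICS_LABEL_NDX, METRICS_PRED_NDX, METRICS_LOSS_NDX, METRICS_SIZE = 0, 1, 2, 3

batch_size, n_samples = 4, 8
metrics = torch.zeros(METRICS_SIZE, n_samples)

logits = torch.randn(batch_size, 2)                     # model output before softmax
label = torch.tensor([[1, 0], [0, 1], [1, 0], [0, 1]])  # [not-malignant, malignant]

loss = nn.CrossEntropyLoss(reduction='none')(logits, label[:, 1])  # one loss per sample
probability = nn.functional.softmax(logits, dim=1)

batch_ndx = 0
start_ndx = batch_ndx * batch_size
end_ndx = start_ndx + label.size(0)
metrics[METRICS_LABEL_NDX, start_ndx:end_ndx] = label[:, 1]
metrics[METRICS_PRED_NDX, start_ndx:end_ndx] = probability[:, 1]
metrics[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss

Backpropagation still uses the batch mean of the per-sample losses, as in the diff above.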

File diff suppressed because it is too large
+ 59 - 0
p2ch11/1_final_metric_f1_score.ipynb


+ 0 - 0
p2ch11/__init__.py


+ 593 - 0
p2ch11/diagnose.py

@@ -0,0 +1,593 @@
+import argparse
+import datetime
+import glob
+import os
+import sys
+
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.optim import SGD, Adam
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import Luna2dSegmentationDataset, LunaClassificationDataset, getCt, getNoduleInfoList
+from util.logconf import logging
+from util.util import xyz2irc, irc2xyz
+from .model import UNetWrapper, LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+# METRICS_LABEL_NDX=0
+# METRICS_PRED_NDX=1
+# METRICS_LOSS_NDX=2
+# METRICS_MAL_LOSS_NDX=3
+# METRICS_BEN_LOSS_NDX=4
+# METRICS_LUNG_LOSS_NDX=5
+# METRICS_MASKLOSS_NDX=2
+# METRICS_MALLOSS_NDX=3
+
+
+METRICS_LOSS_NDX = 0
+METRICS_LABEL_NDX = 1
+METRICS_MFOUND_NDX = 2
+
+METRICS_MOK_NDX = 3
+METRICS_MTP_NDX = 4
+METRICS_MFN_NDX = 5
+METRICS_MFP_NDX = 6
+METRICS_BTP_NDX = 7
+METRICS_BFN_NDX = 8
+METRICS_BFP_NDX = 9
+
+METRICS_MAL_LOSS_NDX = 10
+METRICS_BEN_LOSS_NDX = 11
+METRICS_SIZE = 12
+
+
+
+
+class LunaDiagnoseApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            log.debug(sys.argv)
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=4,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+
+        parser.add_argument('--series-uid',
+            help='Limit inference to this Series UID only.',
+            default=None,
+            type=str,
+        )
+
+
+        parser.add_argument('segmentation_path',
+            help="Path to the saved segmentation model",
+            nargs='?',
+            default=None,
+        )
+
+        parser.add_argument('classification_path',
+            help="Path to the saved classification model",
+            nargs='?',
+            default=None,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch10',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+
+        self.cli_args = parser.parse_args(sys_argv)
+        # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        # self.optimizer = self.initOptimizer()
+
+        if not self.cli_args.segmentation_path:
+            file_path = os.path.join('data-unversioned', 'models', self.cli_args.tb_prefix, 'seg_{}_{}.{}.state'.format('*', '*', 'best'))
+            # log.debug(file_path)
+            self.cli_args.segmentation_path = glob.glob(file_path)[-1]
+
+        log.debug(self.cli_args.segmentation_path)
+
+        # if not self.cli_args.classification_path:
+        #     file_path = os.path.join('data-unversioned', 'models', self.cli_args.tb_prefix, 'cls_{}_{}.{}.state'.format('*', '*', 'best'))
+        #     self.cli_args.classification_path = glob.glob(file_path)[-1]
+
+        self.seg_model, self.cls_model = self.initModels()
+
+
+
+    def initModels(self):
+        log.debug(self.cli_args.segmentation_path)
+        seg_dict = torch.load(self.cli_args.segmentation_path)
+
+        seg_model = UNetWrapper(in_channels=8, n_classes=2, depth=5, wf=6, padding=True, batch_norm=True, up_mode='upconv')
+        seg_model.load_state_dict(seg_dict['model_state'])
+        seg_model.eval()
+
+        # cls_dict = torch.load(self.cli_args.segmentation_path)
+
+        cls_model = LunaModel()
+        # cls_model.load_state_dict(cls_dict['model_state'])
+        cls_model.eval()
+
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                seg_model = nn.DataParallel(seg_model)
+                cls_model = nn.DataParallel(cls_model)
+
+            seg_model = seg_model.to(self.device)
+            cls_model = cls_model.to(self.device)
+
+        return seg_model, cls_model
+
+
+    def initSegmentationDl(self, series_uid):
+        seg_ds = Luna2dSegmentationDataset(
+                test_stride=10,
+                contextSlices_count=3,
+                series_uid=series_uid,
+            )
+        seg_dl = DataLoader(
+            seg_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return seg_dl
+
+    def initClassificationDl(self):
+        seg_ds = LunaClassificationDataset(
+                test_stride=10,
+                # contextSlices_count=3,
+                series_uid=self.cli_args.series_uid,
+            )
+        seg_dl = DataLoader(
+            seg_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return seg_dl
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        if self.cli_args.series_uid:
+            series_list = [self.cli_args.series_uid]
+        else:
+            series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
+        with torch.no_grad():
+            series_iter = enumerateWithEstimate(
+                series_list,
+                "Series",
+            )
+            for series_ndx, series_uid in series_iter:
+                seg_dl = self.initSegmentationDl(series_uid)
+                ct = getCt(series_uid)
+
+                output_ary = np.zeros_like(ct.ary, dtype=np.float32)
+
+                # testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
+                batch_iter = enumerateWithEstimate(
+                    seg_dl,
+                    "Seg " + series_uid,
+                    start_ndx=seg_dl.num_workers,
+                )
+                for batch_ndx, batch_tup in batch_iter:
+                    # self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+                    input_tensor, label_tensor, _series_list, ndx_list = batch_tup
+
+                    input_devtensor = input_tensor.to(self.device)
+
+                    prediction_devtensor = self.seg_model(input_devtensor)
+
+                    for i, sample_ndx in enumerate(ndx_list):
+                        output_ary[sample_ndx] = prediction_devtensor[i].detach().cpu().numpy()
+
+                irc = (output_ary > 0.5).nonzero()
+                xyz = irc2xyz(irc, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)
+
+                print(irc, xyz)
+
+
+        #
+        #         cls_dl = self.initClassificationDl(series_uid)
+        #
+        #         # testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
+        #         batch_iter = enumerateWithEstimate(
+        #             cls_dl,
+        #             "Cls " + series_uid,
+        #             start_ndx=cls_dl.num_workers,
+        #         )
+        #         for batch_ndx, batch_tup in batch_iter:
+        #             self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+        #
+        #
+        #
+        #
+        #
+        #
+        #
+        #
+        # for epoch_ndx in range(1, self.cli_args.epochs + 1):
+        #     train_dl = self.initTrainDl(epoch_ndx)
+        #
+        #     log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+        #         epoch_ndx,
+        #         self.cli_args.epochs,
+        #         len(train_dl),
+        #         len(test_dl),
+        #         self.cli_args.batch_size,
+        #         (torch.cuda.device_count() if self.use_cuda else 1),
+        #     ))
+        #
+        #     trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
+        #     if self.cli_args.segmentation:
+        #         self.logImages(epoch_ndx, train_dl, test_dl)
+        #
+        #     testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
+        #     self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
+        #
+        #     self.saveModel(epoch_ndx)
+        #
+        # if hasattr(self, 'trn_writer'):
+        #     self.trn_writer.close()
+        #     self.tst_writer.close()
+
+    def doTraining(self, epoch_ndx, train_dl):
+        self.model.train()
+        trainingMetrics_tensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset))
+        train_dl.dataset.shuffleSamples()
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            if self.cli_args.segmentation:
+                loss_var = self.computeSegmentationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
+            else:
+                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
+
+            if loss_var is not None:
+                loss_var.backward()
+                self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trainingMetrics_tensor.size(1)
+
+        return trainingMetrics_tensor
+
+    def doTesting(self, epoch_ndx, test_dl):
+        with torch.no_grad():
+            self.model.eval()
+            testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
+            batch_iter = enumerateWithEstimate(
+                test_dl,
+                "E{} Testing ".format(epoch_ndx),
+                start_ndx=test_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                if self.cli_args.segmentation:
+                    self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+                else:
+                    self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+
+        return testingMetrics_tensor
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
+        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+
+        input_devtensor = input_tensor.to(self.device)
+        label_devtensor = label_tensor.to(self.device)
+
+        prediction_devtensor = self.model(input_devtensor)
+        loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
+        metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
+        metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.to('cpu')
+        metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')
+
+        # TODO: replace with torch.autograd.detect_anomaly
+        # assert np.isfinite(metrics_tensor).all()
+
+        return loss_devtensor.mean()
+
+    def computeSegmentationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
+        input_tensor, label_tensor, _series_list, _start_list = batch_tup
+
+        # if label_tensor.max() < 0.5:
+        #     return None
+
+        input_devtensor = input_tensor.to(self.device)
+        label_devtensor = label_tensor.to(self.device)
+
+        prediction_devtensor = self.model(input_devtensor)
+
+        # assert prediction_devtensor.is_contiguous()
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
+        max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
+        intersectionSum = lambda a, b: (a * b.to(torch.float32)).view(a.size(0), -1).sum(dim=1)
+
+        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
+
+        with torch.no_grad():
+
+            boolPrediction_tensor = prediction_devtensor.to('cpu') > 0.5
+
+            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = max2(label_tensor[:,0])
+            metrics_tensor[METRICS_MFOUND_NDX, start_ndx:end_ndx] = (max2(label_tensor[:, 0] * boolPrediction_tensor[:, 1].to(torch.float32)) > 0.5)
+
+            metrics_tensor[METRICS_MOK_NDX, start_ndx:end_ndx] = intersectionSum( label_tensor[:,0],  torch.max(boolPrediction_tensor, dim=1)[0])
+
+            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = intersectionSum( label_tensor[:,0],  boolPrediction_tensor[:,0])
+            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = intersectionSum( label_tensor[:,0], ~boolPrediction_tensor[:,0])
+            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,0],  boolPrediction_tensor[:,0])
+
+            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = intersectionSum( label_tensor[:,1],  boolPrediction_tensor[:,1])
+            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = intersectionSum( label_tensor[:,1], ~boolPrediction_tensor[:,1])
+            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,1],  boolPrediction_tensor[:,1])
+
+            diceLoss_tensor = diceLoss_devtensor.to('cpu')
+            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_tensor
+
+            malLoss_devtensor = self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0])
+            malLoss_tensor = malLoss_devtensor.to('cpu')#.unsqueeze(1)
+            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_tensor
+
+            benLoss_devtensor = self.diceLoss(label_devtensor[:,1], prediction_devtensor[:,1])
+            benLoss_tensor = benLoss_devtensor.to('cpu')#.unsqueeze(1)
+            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = benLoss_tensor
+
+            # lungLoss_devtensor = self.diceLoss(label_devtensor[:,2], prediction_devtensor[:,2])
+            # lungLoss_tensor = lungLoss_devtensor.to('cpu').unsqueeze(1)
+            # metrics_tensor[METRICS_LUNG_LOSS_NDX, start_ndx:end_ndx] = lungLoss_tensor
+
+        # TODO: replace with torch.autograd.detect_anomaly
+        # assert np.isfinite(metrics_tensor).all()
+
+        # return nn.MSELoss()(prediction_devtensor, label_devtensor)
+
+        return diceLoss_devtensor.mean()
+        # return self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0]).mean()
+
+    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01):
+        # sum2 = lambda t: t.sum([1,2,3,4])
+        sum2 = lambda t: t.view(t.size(0), -1).sum(dim=1)
+        # max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
+
+        diceCorrect_devtensor = sum2(prediction_devtensor * label_devtensor)
+        dicePrediction_devtensor = sum2(prediction_devtensor)
+        diceLabel_devtensor = sum2(label_devtensor)
+        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
+        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)
+
+        return diceLoss_devtensor
+
+
+
+    def logImages(self, epoch_ndx, train_dl, test_dl):
+        if epoch_ndx > 0: # TODO revert
+            self.initTensorboardWriters()
+
+            for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
+                for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
+                    ct = getCt(series_uid)
+                    noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
+                    center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)
+
+                    sample_tup = dl.dataset[(series_uid, int(center_irc.index))]
+                    input_tensor = sample_tup[0].unsqueeze(0)
+                    label_tensor = sample_tup[1].unsqueeze(0)
+
+                    input_devtensor = input_tensor.to(self.device)
+                    label_devtensor = label_tensor.to(self.device)
+
+                    prediction_devtensor = self.model(input_devtensor)
+                    prediction_ary = prediction_devtensor.to('cpu').detach().numpy()
+
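+                    # Build an RGB overlay: a dimmed CT context slice forms the gray
+                    # base, the malignant (channel 0) prediction is added to red, and
+                    # the benign (channel 1) prediction is added to green.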
+                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
+                    image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
+                    image_ary[:,:,0] += prediction_ary[0,0] * 0.5
+                    image_ary[:,:,1] += prediction_ary[0,1] * 0.25
+                    # image_ary[:,:,2] += prediction_ary[0,2] * 0.5
+
+                    # log.debug([image_ary.__array_interface__['typestr']])
+
+                    # image_ary = (image_ary * 255).astype(np.uint8)
+
+                    # log.debug([image_ary.__array_interface__['typestr']])
+
+                    writer = getattr(self, mode_str + '_writer')
+                    writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count)
+
+                    if epoch_ndx == 1:
+                        label_ary = label_tensor.numpy()
+
+                        image_ary = np.zeros((512, 512, 3), dtype=np.float32)
+                        image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
+                        image_ary[:,:,0] += label_ary[0,0] * 0.5
+                        image_ary[:,:,1] += label_ary[0,1] * 0.25
+                        image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(np.bool) | label_ary[0,1].astype(np.bool))) * 0.25
+
+                        # log.debug([image_ary.__array_interface__['typestr']])
+
+                        image_ary = (image_ary * 255).astype(np.uint8)
+
+                        # log.debug([image_ary.__array_interface__['typestr']])
+
+                        writer = getattr(self, mode_str + '_writer')
+                        writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count)
+
+
+    def logMetrics(self,
+                   epoch_ndx,
+                   trainingMetrics_tensor,
+                   testingMetrics_tensor,
+                   classificationThreshold_float=0.5,
+                   ):
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+
+        for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
+            metrics_ary = metrics_tensor.cpu().detach().numpy()
+            sum_ary = metrics_ary.sum(axis=1)
+            assert np.isfinite(metrics_ary).all()
+
+            malLabel_mask = metrics_ary[METRICS_LABEL_NDX] > classificationThreshold_float
+            malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
+
+            # malLabel_mask = ~benLabel_mask
+            # malPred_mask = ~benPred_mask
+
+            benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
+            malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]
+
+            trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
+            truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]
+#
+#             falsePos_count = benLabel_count - benCorrect_count
+#             falseNeg_count = malLabel_count - malCorrect_count
+
+
+            metrics_dict = {}
+            metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+            # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
+            # metrics_dict['loss/mal'] = metrics_ary[METRICS_MALLOSS_NDX].mean()
+            # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
+            metrics_dict['loss/mal'] = metrics_ary[METRICS_MAL_LOSS_NDX].mean()
+            metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX].mean()
+
+            metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
+            metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100
+
+            metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
+            metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100
+
+            precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
+            recall    = metrics_dict['pr/recall']    = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
+
+            metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)
+
+            log.info(("E{} {:8} "
+                     + "{loss/all:.4f} loss, "
+                     + "{flagged/all:-5.1f}% pixels flagged, "
+                     + "{flagged/slices:-5.1f}% slices flagged, "
+                     + "{pr/precision:.4f} precision, "
+                     + "{pr/recall:.4f} recall, "
+                     + "{pr/f1_score:.4f} f1 score"
+                      ).format(
+                epoch_ndx,
+                mode_str,
+                **metrics_dict,
+            ))
+            log.info(("E{} {:8} "
+                     + "{loss/mal:.4f} loss, "
+                     + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_mal',
+                malCorrect_count=malCorrect_count,
+                malLabel_count=malLabel_count,
+                **metrics_dict,
+            ))
+            log.info(("E{} {:8} "
+                     + "{loss/ben:.4f} loss, "
+                     + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_msk',
+                benCorrect_count=benCorrect_count,
+                benLabel_count=benLabel_count,
+                **metrics_dict,
+            ))
+
+            if epoch_ndx > 0: # TODO revert
+                self.initTensorboardWriters()
+                writer = getattr(self, mode_str + '_writer')
+
+                for key, value in metrics_dict.items():
+                    writer.add_scalar('seg_' + key, value, self.totalTrainingSamples_count)
+
+#                 writer.add_pr_curve(
+#                     'pr',
+#                     metrics_ary[METRICS_LABEL_NDX],
+#                     metrics_ary[METRICS_PRED_NDX],
+#                     self.totalTrainingSamples_count,
+#                 )
+
+#                 benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
+#                 malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
+#
+#                 bins = [x/50.0 for x in range(51)]
+#                 writer.add_histogram(
+#                     'is_ben',
+#                     metrics_ary[METRICS_PRED_NDX, benHist_mask],
+#                     self.totalTrainingSamples_count,
+#                     bins=bins,
+#                 )
+#                 writer.add_histogram(
+#                     'is_mal',
+#                     metrics_ary[METRICS_PRED_NDX, malHist_mask],
+#                     self.totalTrainingSamples_count,
+#                     bins=bins,
+#                 )
+
+    def saveModel(self, epoch_ndx):
+        file_path = os.path.join('data', 'models', self.cli_args.tb_prefix, '{}_{}.{}.state'.format(self.time_str, self.cli_args.comment, self.totalTrainingSamples_count))
+
+        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
+
+        state = {
+            'model_state': self.model.state_dict(),
+            'model_name': type(self.model).__name__,
+            'optimizer_state' : self.optimizer.state_dict(),
+            'optimizer_name': type(self.optimizer).__name__,
+            'epoch': epoch_ndx,
+            'totalTrainingSamples_count': self.totalTrainingSamples_count,
+            # 'resumed_from': self.cli_args.resume,
+        }
+        torch.save(state, file_path)
+
+        log.debug("Saved model params to {}".format(file_path))
+
+
+if __name__ == '__main__':
+    sys.exit(LunaDiagnoseApp().main() or 0)

+ 316 - 0
p2ch11/dsets.py

@@ -0,0 +1,316 @@
+import copy
+import csv
+import functools
+import glob
+import math
+import os
+import random
+
+from collections import namedtuple
+
+import SimpleITK as sitk
+
+import numpy as np
+import torch
+import torch.cuda
+from torch.utils.data import Dataset
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from util.disk import getCache
+from util.util import XyzTuple, xyz2irc
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+raw_cache = getCache('part2ch11_raw')
+
+NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
+
+@functools.lru_cache(1)
+def getNoduleInfoList(requireDataOnDisk_bool=True):
+    # We construct a set with all series_uids that are present on disk.
+    # This will let us use the data, even if we haven't downloaded all of
+    # the subsets yet.
+    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
+    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+
+    diameter_dict = {}
+    with open('data/part2/luna/annotations.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
+            annotationDiameter_mm = float(row[4])
+
+            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
+
+    noduleInfo_list = []
+    with open('data/part2/luna/candidates.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+
+            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+                continue
+
+            isMalignant_bool = bool(int(row[4]))
+            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
+
+            candidateDiameter_mm = 0.0
+            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
+                for i in range(3):
+                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
+                    if delta_mm > annotationDiameter_mm / 4:
+                        break
+                else:
+                    candidateDiameter_mm = annotationDiameter_mm
+                    break
+
+            noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+
+    noduleInfo_list.sort(reverse=True)
+    return noduleInfo_list
+
+class Ct(object):
+    def __init__(self, series_uid):
+        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
+
+        ct_mhd = sitk.ReadImage(mhd_path)
+        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+
+        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
+        # HU are scaled so that air (roughly 0 g/cc) is -1000 and water (1 g/cc) is 0.
+        # Clamping the lower bound removes the very negative padding values used to
+        # mark voxels outside the scanner's field of view.
+        ct_ary[ct_ary < -1000] = -1000
+
+        # This nukes any weird hotspots and clamps bone down
+        ct_ary[ct_ary > 1000] = 1000
+
+        self.series_uid = series_uid
+        self.ary = ct_ary
+
+        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
+        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
+        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
+
+    def getRawNodule(self, center_xyz, width_irc):
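+        # center_xyz is in patient-space millimeters; xyz2irc converts it to
+        # (index, row, col) voxel coordinates using the CT's origin, voxel size,
+        # and axis-direction metadata.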
+        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
+
+        slice_list = []
+        for axis, center_val in enumerate(center_irc):
+            start_ndx = int(round(center_val - width_irc[axis]/2))
+            end_ndx = int(start_ndx + width_irc[axis])
+
+            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+
+            if start_ndx < 0:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                start_ndx = 0
+                end_ndx = int(width_irc[axis])
+
+            if end_ndx > self.ary.shape[axis]:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                end_ndx = self.ary.shape[axis]
+                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
+
+            slice_list.append(slice(start_ndx, end_ndx))
+
+        ct_chunk = self.ary[tuple(slice_list)]
+
+        return ct_chunk, center_irc
+
+
+@functools.lru_cache(1, typed=True)
+def getCt(series_uid):
+    return Ct(series_uid)
+
+@raw_cache.memoize(typed=True)
+def getCtRawNodule(series_uid, center_xyz, width_irc):
+    ct = getCt(series_uid)
+    ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+    return ct_chunk, center_irc
+
+def getCtAugmentedNodule(
+        augmentation_dict,
+        series_uid, center_xyz, width_irc,
+        use_cache=True):
+    if use_cache:
+        ct_chunk, center_irc = getCtRawNodule(series_uid, center_xyz, width_irc)
+    else:
+        ct = getCt(series_uid)
+        ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+
+    ct_tensor = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
+
+    transform_tensor = torch.eye(4).to(torch.float64)
+    # ... <1>
+
+    for i in range(3):
+        if 'flip' in augmentation_dict:
+            if random.random() > 0.5:
+                transform_tensor[i,i] *= -1
+
+        if 'offset' in augmentation_dict:
+            offset_float = augmentation_dict['offset']
+            random_float = (random.random() * 2 - 1)
+            transform_tensor[i,3] = offset_float * random_float
+
+        if 'scale' in augmentation_dict:
+            scale_float = augmentation_dict['scale']
+            random_float = (random.random() * 2 - 1)
+            transform_tensor[i,i] *= 1.0 + scale_float * random_float
+
+
+    if 'rotate' in augmentation_dict:
+        angle_rad = random.random() * math.pi * 2
+        s = math.sin(angle_rad)
+        c = math.cos(angle_rad)
+
+        rotation_tensor = torch.tensor([
+            [c, -s, 0, 0],
+            [s, c, 0, 0],
+            [0, 0, 1, 0],
+            [0, 0, 0, 1],
+        ], dtype=torch.float64)
+
+        transform_tensor @= rotation_tensor
+
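+    # F.affine_grid takes the batched 3x4 affine matrix (translation in the last
+    # column) and builds a sampling grid matching ct_tensor; F.grid_sample then
+    # resamples the chunk through the combined flip/offset/scale/rotate transform,
+    # repeating border values for samples that fall outside the chunk.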
+    affine_tensor = F.affine_grid(
+            transform_tensor[:3].unsqueeze(0).to(torch.float32),
+            ct_tensor.size(),
+        )
+
+    augmented_chunk = F.grid_sample(
+            ct_tensor,
+            affine_tensor,
+            padding_mode='border'
+        ).to('cpu')
+
+    if 'noise' in augmentation_dict:
+        noise_tensor = torch.randn_like(augmented_chunk)
+        noise_tensor *= augmentation_dict['noise']
+
+        augmented_chunk += noise_tensor
+
+    return augmented_chunk[0], center_irc
+
+
+class LunaDataset(Dataset):
+    def __init__(self,
+                 test_stride=0,
+                 isTestSet_bool=None,
+                 series_uid=None,
+                 sortby_str='random',
+                 ratio_int=0,
+                 augmentation_dict=None,
+                 noduleInfo_list=None,
+            ):
+        self.ratio_int = ratio_int
+        self.augmentation_dict = augmentation_dict
+
+        if noduleInfo_list:
+            self.noduleInfo_list = copy.copy(noduleInfo_list)
+            self.use_cache = False
+        else:
+            self.noduleInfo_list = copy.copy(getNoduleInfoList())
+            self.use_cache = True
+
+        if series_uid:
+            self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid == series_uid]
+
+        if test_stride > 1:
+            if isTestSet_bool:
+                self.noduleInfo_list = self.noduleInfo_list[::test_stride]
+            else:
+                del self.noduleInfo_list[::test_stride]
+
+        if sortby_str == 'random':
+            random.shuffle(self.noduleInfo_list)
+        elif sortby_str == 'series_uid':
+            self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sort by (series_uid, center_xyz)
+        elif sortby_str == 'malignancy_size':
+            pass
+        else:
+            raise Exception("Unknown sort: " + repr(sortby_str))
+
+        self.benign_list = [nt for nt in self.noduleInfo_list if not nt.isMalignant_bool]
+        self.malignant_list = [nt for nt in self.noduleInfo_list if nt.isMalignant_bool]
+
+        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
+            self,
+            len(self.noduleInfo_list),
+            "testing" if isTestSet_bool else "training",
+            len(self.benign_list),
+            len(self.malignant_list),
+            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
+        ))
+
+    def shuffleSamples(self):
+        if self.ratio_int:
+            random.shuffle(self.benign_list)
+            random.shuffle(self.malignant_list)
+
+    def __len__(self):
+        if self.ratio_int:
+            return 200000
+        else:
+            return len(self.noduleInfo_list)
+
+    def __getitem__(self, ndx):
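+        # With ratio_int == N, the epoch alternates one malignant sample with N
+        # benign ones: ndx // (N + 1) picks the malignant sample, and the remaining
+        # indices wrap around the (shuffled) benign list.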
+        if self.ratio_int:
+            malignant_ndx = ndx // (self.ratio_int + 1)
+
+            if ndx % (self.ratio_int + 1):
+                benign_ndx = ndx - 1 - malignant_ndx
+                benign_ndx %= len(self.benign_list)
+                nodule_tup = self.benign_list[benign_ndx]
+            else:
+                malignant_ndx %= len(self.malignant_list)
+                nodule_tup = self.malignant_list[malignant_ndx]
+        else:
+            nodule_tup = self.noduleInfo_list[ndx]
+
+        width_irc = (24, 48, 48)
+
+        if self.augmentation_dict:
+            nodule_t, center_irc = getCtAugmentedNodule(
+                self.augmentation_dict,
+                nodule_tup.series_uid,
+                nodule_tup.center_xyz,
+                width_irc,
+                self.use_cache,
+            )
+        elif self.use_cache:
+            nodule_ary, center_irc = getCtRawNodule(
+                nodule_tup.series_uid,
+                nodule_tup.center_xyz,
+                width_irc,
+            )
+            nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
+            nodule_t = nodule_t.unsqueeze(0)
+        else:
+            ct = getCt(nodule_tup.series_uid)
+            nodule_ary, center_irc = ct.getRawNodule(
+                nodule_tup.center_xyz,
+                width_irc,
+            )
+            nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
+            nodule_t = nodule_t.unsqueeze(0)
+
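+        # Two-element [not-malignant, malignant] label; the training loop uses
+        # column 1 as the class index for CrossEntropyLoss.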
+        malignant_tensor = torch.tensor([
+                not nodule_tup.isMalignant_bool,
+                nodule_tup.isMalignant_bool
+            ],
+            dtype=torch.long,
+        )
+
+        return nodule_t, malignant_tensor, nodule_tup.series_uid, center_irc
+
+
+

+ 63 - 0
p2ch11/model.py

@@ -0,0 +1,63 @@
+import math
+
+import torch.nn as nn
+
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+
+class LunaModel(nn.Module):
+    def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
+        super().__init__()
+
+        self.input_batchnorm = nn.BatchNorm3d(1)  # input is 5D (B, 1, D, H, W), so BatchNorm3d
+
+        layer_list = []
+        for layer_ndx in range(layer_count):
+            layer_list += [
+                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
+                nn.ReLU(inplace=True),
+                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
+                nn.ReLU(inplace=True),
+                nn.MaxPool3d(2, 2),
+            ]
+
+            in_channels = conv_channels
+            conv_channels *= 2
+
+        self.convAndPool_seq = nn.Sequential(*layer_list)
+        self.fullyConnected_layer = nn.Linear(576, 2)
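+        # 576 = 64 * 1 * 3 * 3: with the defaults (4 blocks, 8 starting channels),
+        # the (1, 24, 48, 48) input chunk is pooled down to (64, 1, 3, 3) before
+        # being flattened for the fully connected layer.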
+        self.final = nn.Softmax(dim=1)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        # see also https://github.com/pytorch/pytorch/issues/18182
+        for m in self.modules():
+            if type(m) in {
+                nn.Conv2d,
+                nn.Conv3d,
+                nn.ConvTranspose2d,
+                nn.ConvTranspose3d,
+                nn.Linear,
+            }:
+                # log.debug(m)
+                # nn.init.kaiming_normal_(m.weight.data, mode='fan_out', a=0)
+                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
+
+    def forward(self, input_batch):
+        bn_output = self.input_batchnorm(input_batch)
+        conv_output = self.convAndPool_seq(bn_output)
+        conv_flat = conv_output.view(conv_output.size(0), -1)
+        classifier_output = self.fullyConnected_layer(conv_flat)
+
+        return classifier_output, self.final(classifier_output)
+

+ 328 - 0
p2ch11/model_segmentation.py

@@ -0,0 +1,328 @@
+import torch
+from torch import nn as nn
+
+from util.logconf import logging
+from util.unet import UNet
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+# torch.backends.cudnn.enabled = False
+
+class UNetWrapper(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
+        self.unet = UNet(**kwargs)
+        self.hardtanh = nn.Hardtanh(min_val=0, max_val=1)
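+        # Hardtanh(0, 1) clamps the raw U-Net output into [0, 1] so it can be read
+        # as a per-pixel mask value.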
+
+    def forward(self, input):
+        bn_output = self.batchnorm(input)
+        un_output = self.unet(bn_output)
+        ht_output = self.hardtanh(un_output)
+
+        return ht_output
+
+
+
+class Simple2dSegmentationModel(nn.Module):
+    def __init__(self, layers, in_channels, conv_channels, final_channels):
+        super().__init__()
+        self.layers = layers
+
+        self.in_channels = in_channels
+        self.conv_channels = conv_channels
+        self.final_channels = final_channels
+
+        layer_list = [
+            nn.Conv2d(self.in_channels, self.conv_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(self.conv_channels),
+            # nn.GroupNorm(1, self.conv_channels),
+            # nn.ReLU(inplace=True),
+            nn.LeakyReLU(inplace=True),
+        ]
+
+        for i in range(self.layers):
+            layer_list.extend([
+                nn.Conv2d(self.conv_channels, self.conv_channels, kernel_size=3, padding=1),
+                nn.BatchNorm2d(self.conv_channels),
+                # nn.GroupNorm(1, self.conv_channels),
+                # nn.ReLU(inplace=True),
+                nn.LeakyReLU(inplace=True),
+            ])
+
+        layer_list.extend([
+            nn.Conv2d(self.conv_channels, self.final_channels, kernel_size=1, bias=True),
+            nn.Hardtanh(min_val=0, max_val=1),
+        ])
+
+        self.layer_seq = nn.Sequential(*layer_list)
+
+
+    def forward(self, in_data):
+        return self.layer_seq(in_data)
+
+
+class Dense2dSegmentationModel(nn.Module):
+    def __init__(self, layers, input_channels, conv_channels, bottleneck_channels, final_channels):
+        super().__init__()
+        self.layers = layers
+
+        self.input_channels = input_channels
+        self.conv_channels = conv_channels
+        self.bottleneck_channels = bottleneck_channels
+        self.final_channels = final_channels
+
+        self.layer_list = nn.ModuleList()
+
+        for i in range(layers):
+            self.layer_list.append(
+                Dense2dSegmentationBlock(
+                    input_channels + bottleneck_channels * i,
+                    conv_channels,
+                    bottleneck_channels,
+                )
+            )
+
+        self.layer_list.append(
+            Dense2dSegmentationBlock(
+                input_channels + bottleneck_channels * layers,
+                conv_channels,
+                bottleneck_channels,
+                final_channels,
+            )
+        )
+
+        self.htanh_layer = nn.Hardtanh(min_val=0, max_val=1)
+
+    def forward(self, input_tensor):
+        concat_list = [input_tensor]
+        for layer_block in self.layer_list:
+            layer_output = layer_block(torch.cat(concat_list, dim=1))
+            concat_list.append(layer_output)
+
+        return self.htanh_layer(concat_list[-1])
+
+
+class Dense2dSegmentationBlock(nn.Module):
+    def __init__(self, input_channels, conv_channels, bottleneck_channels, final_channels=None):
+        super().__init__()
+
+        self.input_channels = input_channels
+        self.conv_channels = conv_channels
+        self.bottleneck_channels = bottleneck_channels
+        self.final_channels = final_channels or bottleneck_channels
+
+        self.conv1_seq = nn.Sequential(
+            nn.Conv2d(self.input_channels, self.bottleneck_channels, kernel_size=1),
+            nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
+            nn.Conv2d(self.conv_channels, self.bottleneck_channels, kernel_size=1),
+            # nn.BatchNorm2d(self.conv_channels),
+            nn.GroupNorm(1, self.bottleneck_channels),
+            # nn.ReLU(inplace=True),
+            nn.LeakyReLU(inplace=True),
+        )
+
+        self.conv2_seq = nn.Sequential(
+            nn.Conv2d(self.input_channels + self.bottleneck_channels, self.bottleneck_channels, kernel_size=1),
+            nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
+            nn.Conv2d(self.conv_channels, self.final_channels, kernel_size=1),
+            # nn.BatchNorm2d(self.conv_channels),
+            nn.GroupNorm(1, self.final_channels),
+            # nn.ReLU(inplace=True),
+            nn.LeakyReLU(inplace=True),
+        )
+
+    def forward(self, input_tensor):
+        conv1_tensor = self.conv1_seq(input_tensor)
+        conv2_tensor = self.conv2_seq(torch.cat([input_tensor, conv1_tensor], dim=1))
+
+        return conv2_tensor
+
+
+class SegmentationModel(nn.Module):
+    def __init__(self, depth, in_channels, tail_channels=None, out_channels=None, final_channels=None):
+        super().__init__()
+        self.depth = depth
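+        # Recursive U-Net-style block: tail convs at the current resolution, an
+        # optional child block at half resolution (via max-pool), whose upsampled
+        # output is concatenated with the input and tail output (the skip
+        # connection) before the head convs.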
+
+        # self.in_size = in_size
+        # self.tailOut_size = in_size #self.in_size - 4
+        # self.headIn_size = in_size #None
+        # self.out_size = in_size #None
+
+        self.in_channels = in_channels
+        self.tailOut_channels = tail_channels or in_channels * 2
+        self.headIn_channels = None
+        self.out_channels = out_channels or self.tailOut_channels
+        self.final_channels = final_channels
+
+        # assert in_size % 2 == 0, repr([in_size, depth])
+
+        self.tail_seq = nn.Sequential(
+            nn.ReplicationPad3d(2),
+            nn.Conv3d(self.in_channels, self.tailOut_channels, 3),
+            nn.GroupNorm(1, self.tailOut_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(self.tailOut_channels, self.tailOut_channels, 3),
+            nn.GroupNorm(1, self.tailOut_channels),
+            nn.ReLU(inplace=True),
+        )
+
+        if depth:
+            self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
+            self.child_layer = SegmentationModel(depth - 1, self.tailOut_channels)
+
+            self.headIn_channels = self.in_channels + self.tailOut_channels + self.child_layer.out_channels
+            # self.headIn_size = self.child_layer.out_size * 2
+            # self.out_size = self.headIn_size #- 4
+
+            # self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear')
+        else:
+            self.downsample_layer = None
+            self.child_layer = None
+            # self.upsample_layer = None
+
+            self.headIn_channels = self.in_channels + self.tailOut_channels
+            # self.headIn_size = self.tailOut_size
+            # self.out_size = self.headIn_size #- 4
+
+        self.head_seq = nn.Sequential(
+            nn.ReplicationPad3d(2),
+            nn.Conv3d(self.headIn_channels, self.out_channels, 3),
+            nn.GroupNorm(1, self.out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(self.out_channels, self.out_channels, 3),
+            nn.GroupNorm(1, self.out_channels),
+            nn.ReLU(inplace=True),
+        )
+
+        if self.final_channels:
+            self.final_seq = nn.Sequential(
+                nn.ReplicationPad3d(1),
+                nn.Conv3d(self.out_channels, self.final_channels, 1),
+            )
+        else:
+            self.final_seq = None
+
+    def forward(self, in_data):
+
+        assert in_data.is_contiguous()
+
+        try:
+            tail_out = self.tail_seq(in_data)
+        except:
+            log.debug([in_data.size()])
+            raise
+
+        if self.downsample_layer:
+            down_out = self.downsample_layer(tail_out)
+            child_out = self.child_layer(down_out)
+            # up_out = self.upsample_layer(child_out)
+
+            up_out = nn.functional.interpolate(child_out, scale_factor=2, mode='trilinear')
+
+            # crop_int = (tail_out.size(-1) - up_out.size(-1)) // 2
+            # crop_out = tail_out[:, :, crop_int:-crop_int, crop_int:-crop_int, crop_int:-crop_int]
+            # combined_out = torch.cat([crop_out, up_out], 1)
+
+            combined_out = torch.cat([in_data, tail_out, up_out], 1)
+        else:
+            combined_out = torch.cat([in_data, tail_out], 1)
+
+        head_out = self.head_seq(combined_out)
+
+        if self.final_seq:
+            final_out = self.final_seq(head_out)
+            return final_out
+        else:
+            return head_out
+
+
+class DenseSegmentationModel(nn.Module):
+    def __init__(self, depth, in_channels, conv_channels, final_channels=None):
+        super().__init__()
+        self.depth = depth
+
+        self.in_channels = in_channels
+        self.conv_channels = conv_channels
+        self.final_channels = final_channels
+
+        self.convA_seq = nn.Sequential(
+            nn.Conv3d(self.in_channels, self.conv_channels // 4, 1),
+            nn.ReplicationPad3d(1),
+            nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
+            nn.BatchNorm3d(self.conv_channels),
+            nn.ReLU(inplace=True),
+        )
+
+        self.convB_seq = nn.Sequential(
+            nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
+            nn.ReplicationPad3d(1),
+            nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
+            nn.BatchNorm3d(self.conv_channels),
+            nn.ReLU(inplace=True),
+        )
+
+        if self.depth:
+            self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
+            self.child_layer = SegmentationModel(depth - 1, self.conv_channels, self.conv_channels * 2)
+            self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear')
+
+            self.convC_seq = nn.Sequential(
+                nn.Conv3d(self.in_channels + self.conv_channels * 3, self.conv_channels // 4, 1),
+                nn.ReplicationPad3d(1),
+                nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
+                nn.BatchNorm3d(self.conv_channels),
+                nn.ReLU(inplace=True),
+            )
+        else:
+            self.downsample_layer = None
+            self.child_layer = None
+            self.upsample_layer = None
+
+            self.convC_seq = nn.Sequential(
+                nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
+                nn.ReplicationPad3d(1),
+                nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
+                nn.BatchNorm3d(self.conv_channels),
+                nn.ReLU(inplace=True),
+            )
+
+        self.convD_seq = nn.Sequential(
+            nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
+            nn.ReplicationPad3d(1),
+            nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
+            nn.BatchNorm3d(self.conv_channels),
+            nn.ReLU(inplace=True),
+        )
+
+        if self.final_channels:
+            self.final_seq = nn.Sequential(
+                # nn.ReplicationPad3d(1),
+                nn.Conv3d(self.conv_channels, self.final_channels, 1),
+            )
+        else:
+            self.final_seq = None
+
+    def forward(self, data_in):
+        a_out = self.convA_seq(data_in)
+        b_out = self.convB_seq(torch.cat([data_in, a_out], 1))
+
+        if self.downsample_layer:
+            down_out = self.downsample_layer(b_out)
+            child_out = self.child_layer(down_out)
+            up_out = self.upsample_layer(child_out)
+
+            c_out = self.convC_seq(torch.cat([data_in, b_out, up_out], 1))
+        else:
+            c_out = self.convC_seq(torch.cat([data_in, b_out], 1))
+
+        d_out = self.convD_seq(torch.cat([data_in, c_out], 1))
+
+        if self.final_seq:
+            return self.final_seq(d_out)
+        else:
+            return d_out

+ 63 - 0
p2ch11/prepcache.py

@@ -0,0 +1,63 @@
+import argparse
+import sys
+
+import numpy as np
+
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset
+from util.logconf import logging
+from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+
+class LunaPrepCacheApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=1024,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.prep_dl = DataLoader(
+            LunaDataset(
+                sortby_str='series_uid',
+            ),
+            batch_size=self.cli_args.batch_size,
+            num_workers=self.cli_args.num_workers,
+        )
+
+        batch_iter = enumerateWithEstimate(
+            self.prep_dl,
+            "Stuffing cache",
+            start_ndx=self.prep_dl.num_workers,
+        )
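+        # Iterating the DataLoader is all that's needed: each __getitem__ call runs
+        # getCtRawNodule, which fills the on-disk cache via raw_cache.memoize.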
+        for _ in batch_iter:
+            pass
+
+
+if __name__ == '__main__':
+    sys.exit(LunaPrepCacheApp().main() or 0)

+ 420 - 0
p2ch11/training.py

@@ -0,0 +1,420 @@
+import argparse
+import datetime
+import os
+import sys
+
+import numpy as np
+
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset
+from util.logconf import logging
+from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+METRICS_LABEL_NDX=0
+METRICS_PRED_NDX=1
+METRICS_LOSS_NDX=2
+METRICS_SIZE = 3
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=32,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+        parser.add_argument('--balanced',
+            help="Balance the training data to half benign, half malignant.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augmented',
+            help="Augment the training data.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-flip',
+            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-offset',
+            help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-scale',
+            help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-rotate',
+            help="Augment the training data by randomly rotating the data around the head-foot axis.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-noise',
+            help="Augment the training data by randomly adding noise to the data.",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch11',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='none',
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
+
+        self.totalTrainingSamples_count = 0
+        self.trn_writer = None
+        self.tst_writer = None
+
+        self.augmentation_dict = {}
+        if self.cli_args.augmented or self.cli_args.augment_flip:
+            self.augmentation_dict['flip'] = True
+        if self.cli_args.augmented or self.cli_args.augment_offset:
+            self.augmentation_dict['offset'] = 0.1
+        if self.cli_args.augmented or self.cli_args.augment_scale:
+            self.augmentation_dict['scale'] = 0.2
+        if self.cli_args.augmented or self.cli_args.augment_rotate:
+            self.augmentation_dict['rotate'] = True
+        if self.cli_args.augmented or self.cli_args.augment_noise:
+            self.augmentation_dict['noise'] = 25.0
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        self.model = self.initModel()
+        self.optimizer = self.initOptimizer()
+
+
+    def initModel(self):
+        model = LunaModel()
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                model = nn.DataParallel(model)
+            model = model.to(self.device)
+        return model
+
+    def initOptimizer(self):
+        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
+        # return Adam(self.model.parameters())
+
+    def initTrainDl(self):
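+        # test_stride=10 reserves every tenth nodule for the test set; the training
+        # set keeps the rest, optionally rebalanced (ratio_int) and augmented.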
+        train_ds = LunaDataset(
+            test_stride=10,
+            isTestSet_bool=False,
+            ratio_int=int(self.cli_args.balanced),
+            augmentation_dict=self.augmentation_dict,
+        )
+
+        train_dl = DataLoader(
+            train_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return train_dl
+
+    def initTestDl(self):
+        test_ds = LunaDataset(
+            test_stride=10,
+            isTestSet_bool=True,
+        )
+
+        test_dl = DataLoader(
+            test_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return test_dl
+
+    def initTensorboardWriters(self):
+        if self.trn_writer is None:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_cls_' + self.cli_args.comment)
+            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_cls_' + self.cli_args.comment)
+# end::tb_writer[]
+
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+        test_dl = self.initTestDl()
+
+        self.initTensorboardWriters()
+        # self.logModelMetrics(self.model)
+
+        # best_score = 0.0
+
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(test_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            trnMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_tensor)
+
+            tstMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
+            self.logMetrics(epoch_ndx, 'tst', tstMetrics_tensor)
+
+        if hasattr(self, 'trn_writer'):
+            self.trn_writer.close()
+            self.tst_writer.close()
+
+
+    def doTraining(self, epoch_ndx, train_dl):
+        self.model.train()
+        train_dl.dataset.shuffleSamples()
+        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            loss_var = self.computeBatchLoss(
+                batch_ndx,
+                batch_tup,
+                train_dl.batch_size,
+                trainingMetrics_devtensor
+            )
+
+            loss_var.backward()
+            self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
+
+        return trainingMetrics_devtensor.to('cpu')
+
+
+    def doTesting(self, epoch_ndx, test_dl):
+        with torch.no_grad():
+            self.model.eval()
+            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
+            batch_iter = enumerateWithEstimate(
+                test_dl,
+                "E{} Testing ".format(epoch_ndx),
+                start_ndx=test_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
+
+        return testingMetrics_devtensor.to('cpu')
+
+
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
+        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+
+        input_devtensor = input_tensor.to(self.device, non_blocking=True)
+        label_devtensor = label_tensor.to(self.device, non_blocking=True)
+
+        logits_devtensor, probability_devtensor = self.model(input_devtensor)
+
+        loss_func = nn.CrossEntropyLoss(reduction='none')
+        loss_devtensor = loss_func(logits_devtensor, label_devtensor[:,1])
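+        # reduction='none' keeps one loss value per sample so it can be recorded in
+        # the metrics tensor; label_devtensor[:,1] is the integer class index
+        # (0 = benign, 1 = malignant), and the batch mean is what gets backpropagated.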
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
+
+        metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_devtensor[:,1]
+        metrics_devtensor[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_devtensor[:,1]
+        metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor
+
+        return loss_devtensor.mean()
+
+
+    def logMetrics(
+            self,
+            epoch_ndx,
+            mode_str,
+            metrics_tensor,
+    ):
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+        metrics_ary = metrics_tensor.cpu().detach().numpy()
+#         assert np.isfinite(metrics_ary).all()
+
+        benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= 0.5
+        benPred_mask = metrics_ary[METRICS_PRED_NDX] <= 0.5
+
+        malLabel_mask = ~benLabel_mask
+        malPred_mask = ~benPred_mask
+
+        benLabel_count = benLabel_mask.sum()
+        malLabel_count = malLabel_mask.sum()
+
+        trueNeg_count = benCorrect_count = (benLabel_mask & benPred_mask).sum()
+        truePos_count = malCorrect_count = (malLabel_mask & malPred_mask).sum()
+
+        falsePos_count = benLabel_count - benCorrect_count
+        falseNeg_count = malLabel_count - malCorrect_count
+
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
+
+        metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
+        metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
+        metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
+
+        precision = metrics_dict['pr/precision'] = truePos_count / (truePos_count + falsePos_count)
+        recall    = metrics_dict['pr/recall']    = truePos_count / (truePos_count + falseNeg_count)
+
+        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / (precision + recall)
+
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/all:.4f} loss, "
+                 + "{correct/all:-5.1f}% correct, "
+                 + "{pr/precision:.4f} precision, "
+                 + "{pr/recall:.4f} recall, "
+                 + "{pr/f1_score:.4f} f1 score"
+            ).format(
+                epoch_ndx,
+                mode_str,
+                **metrics_dict,
+            )
+        )
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/ben:.4f} loss, "
+                 + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_ben',
+                benCorrect_count=benCorrect_count,
+                benLabel_count=benLabel_count,
+                **metrics_dict,
+            )
+        )
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/mal:.4f} loss, "
+                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_mal',
+                malCorrect_count=malCorrect_count,
+                malLabel_count=malLabel_count,
+                **metrics_dict,
+            )
+        )
+        writer = getattr(self, mode_str + '_writer')
+
+        for key, value in metrics_dict.items():
+            writer.add_scalar(key, value, self.totalTrainingSamples_count)
+
+        writer.add_pr_curve(
+            'pr',
+            metrics_ary[METRICS_LABEL_NDX],
+            metrics_ary[METRICS_PRED_NDX],
+            self.totalTrainingSamples_count,
+        )
+
+        bins = [x/50.0 for x in range(51)]
+
+        benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
+        malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
+
+        if benHist_mask.any():
+            writer.add_histogram(
+                'is_ben',
+                metrics_ary[METRICS_PRED_NDX, benHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+        if malHist_mask.any():
+            writer.add_histogram(
+                'is_mal',
+                metrics_ary[METRICS_PRED_NDX, malHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+
+        # score = 1 \
+        #     + metrics_dict['pr/f1_score'] \
+        #     - metrics_dict['loss/mal'] * 0.01 \
+        #     - metrics_dict['loss/all'] * 0.0001
+        #
+        # return score
+
+    # def logModelMetrics(self, model):
+    #     writer = getattr(self, 'trn_writer')
+    #
+    #     model = getattr(model, 'module', model)
+    #
+    #     for name, param in model.named_parameters():
+    #         if param.requires_grad:
+    #             min_data = float(param.data.min())
+    #             max_data = float(param.data.max())
+    #             max_extent = max(abs(min_data), abs(max_data))
+    #
+    #             # bins = [x/50*max_extent for x in range(-50, 51)]
+    #
+    #             try:
+    #                 writer.add_histogram(
+    #                     name.rsplit('.', 1)[-1] + '/' + name,
+    #                     param.data.cpu().numpy(),
+    #                     # metrics_ary[METRICS_PRED_NDX, benHist_mask],
+    #                     self.totalTrainingSamples_count,
+    #                     # bins=bins,
+    #                 )
+    #             except Exception as e:
+    #                 log.error([min_data, max_data])
+    #                 raise
+
+
+if __name__ == '__main__':
+    sys.exit(LunaTrainingApp().main() or 0)

+ 87 - 0
p2ch11/vis.py

@@ -0,0 +1,87 @@
+import matplotlib
+matplotlib.use('nbagg')
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from p2ch11.dsets import Ct, LunaDataset
+
+clim=(-1000.0, 1300)
+
+def findMalignantSamples(start_ndx=0, limit=10):
+    ds = LunaDataset(sortby_str='malignancy_size')
+
+    malignantSample_list = []
+    for sample_tup in ds.noduleInfo_list:
+        if sample_tup[0]:
+            print(len(malignantSample_list), sample_tup)
+            malignantSample_list.append(sample_tup)
+
+        if len(malignantSample_list) >= limit:
+            break
+
+    return malignantSample_list
+
+def showNodule(series_uid, batch_ndx=None, **kwargs):
+    ds = LunaDataset(series_uid=series_uid, sortby_str='malignancy_size', **kwargs)
+    malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x[0]]
+
+    if batch_ndx is None:
+        if malignant_list:
+            batch_ndx = malignant_list[0]
+        else:
+            print("Warning: no malignant samples found; using first non-malignant sample.")
+            batch_ndx = 0
+
+    ct = Ct(series_uid)
+    # ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
+    # malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor = ds[batch_ndx]
+    nodule_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
+    ct_ary = nodule_tensor[0].numpy()
+
+
+    fig = plt.figure(figsize=(15, 25))
+
+    group_list = [
+        #[0,1,2],
+        [3,4,5],
+        [6,7,8],
+        [9,10,11],
+        #[12,13,14],
+        #[15]
+    ]
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct_ary[ct_ary.shape[0]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct_ary[:,ct_ary.shape[1]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct_ary[:,:,ct_ary.shape[2]//2], clim=clim, cmap='gray')
+
+    for row, index_list in enumerate(group_list):
+        for col, index in enumerate(index_list):
+            subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
+            subplot.set_title('slice {}'.format(index))
+            plt.imshow(ct_ary[index*2], clim=clim, cmap='gray')
+
+
+    print(series_uid, batch_ndx, bool(malignant_tensor[1]), malignant_list, ct.vxSize_xyz)
+
+    return ct_ary

+ 0 - 0
p2ch12/__init__.py


+ 378 - 0
p2ch12/diagnose.py

@@ -0,0 +1,378 @@
+import argparse
+import glob
+import os
+import sys
+
+import numpy as np
+import scipy.ndimage.measurements as measure
+import scipy.ndimage.morphology as morph
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset, Luna2dSegmentationDataset, getCt, getNoduleInfoList, NoduleInfoTuple
+from .model_seg import UNetWrapper
+from .model_cls import LunaModel, AlternateLunaModel
+
+from util.logconf import logging
+from util.util import xyz2irc, irc2xyz
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+
+class LunaDiagnoseApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            log.debug(sys.argv)
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=4,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+
+        parser.add_argument('--series-uid',
+            help='Limit inference to this Series UID only.',
+            default=None,
+            type=str,
+        )
+
+        parser.add_argument('--include-train',
+            help="Include data that was in the training set. (default: test data only)",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--segmentation-path',
+            help="Path to the saved segmentation model",
+            nargs='?',
+            default=None,
+        )
+
+        parser.add_argument('--classification-path',
+            help="Path to the saved classification model",
+            nargs='?',
+            default=None,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch12',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        if not self.cli_args.segmentation_path:
+            self.cli_args.segmentation_path = self.initModelPath('seg')
+
+        if not self.cli_args.classification_path:
+            self.cli_args.classification_path = self.initModelPath('cls')
+
+        self.seg_model, self.cls_model = self.initModels()
+
+    def initModelPath(self, type_str):
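+        # Prefer the newest locally trained '.best.state' checkpoint for this chapter
+        # (the filenames embed the training timestamp, so sorting the glob results and
+        # taking the last entry picks the most recent); if none exists yet, fall back
+        # to the pretrained weights shipped under data/part2/models.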
+        local_path = os.path.join(
+            'data-unversioned',
+            'part2',
+            'models',
+            self.cli_args.tb_prefix,
+            type_str + '_{}_{}.{}.state'.format('*', '*', 'best'),
+        )
+
+        file_list = glob.glob(local_path)
+        if not file_list:
+            pretrained_path = os.path.join(
+                'data',
+                'part2',
+                'models',
+                type_str + '_{}_{}.{}.state'.format('*', '*', '*'),
+            )
+            file_list = glob.glob(pretrained_path)
+        else:
+            pretrained_path = None
+
+        file_list.sort()
+
+        try:
+            return file_list[-1]
+        except IndexError:
+            log.debug([local_path, pretrained_path, file_list])
+            raise
+
+    def initModels(self):
+        log.debug(self.cli_args.segmentation_path)
+        seg_dict = torch.load(self.cli_args.segmentation_path)
+
+        seg_model = UNetWrapper(in_channels=8, n_classes=1, depth=4, wf=3, padding=True, batch_norm=True, up_mode='upconv')
+        seg_model.load_state_dict(seg_dict['model_state'])
+        seg_model.eval()
+
+        log.debug(self.cli_args.classification_path)
+        cls_dict = torch.load(self.cli_args.classification_path)
+
+        cls_model = LunaModel()
+        # cls_model = AlternateLunaModel()
+        cls_model.load_state_dict(cls_dict['model_state'])
+        cls_model.eval()
+
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                seg_model = nn.DataParallel(seg_model)
+                cls_model = nn.DataParallel(cls_model)
+
+            seg_model = seg_model.to(self.device)
+            cls_model = cls_model.to(self.device)
+
+        return seg_model, cls_model
+
+
+    def initSegmentationDl(self, series_uid):
+        seg_ds = Luna2dSegmentationDataset(
+                contextSlices_count=3,
+                series_uid=series_uid,
+                fullCt_bool=True,
+            )
+        seg_dl = DataLoader(
+            seg_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return seg_dl
+
+    def initClassificationDl(self, noduleInfo_list):
+        cls_ds = LunaDataset(
+                sortby_str='series_uid',
+                noduleInfo_list=noduleInfo_list,
+            )
+        cls_dl = DataLoader(
+            cls_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return cls_dl
+
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        test_ds = LunaDataset(
+            test_stride=10,
+            isTestSet_bool=True,
+        )
+        test_set = set(
+            noduleInfo_tup.series_uid
+            for noduleInfo_tup in test_ds.noduleInfo_list
+        )
+        malignant_set = set(
+            noduleInfo_tup.series_uid
+            for noduleInfo_tup in getNoduleInfoList()
+            if noduleInfo_tup.isMalignant_bool
+        )
+
+        if self.cli_args.series_uid:
+            series_set = set(self.cli_args.series_uid.split(','))
+        else:
+            series_set = set(
+                noduleInfo_tup.series_uid
+                for noduleInfo_tup in getNoduleInfoList()
+            )
+
+        train_list = sorted(series_set - test_set) if self.cli_args.include_train else []
+        test_list = sorted(series_set & test_set)
+
+
+        noduleInfo_list = []
+        series_iter = enumerateWithEstimate(
+            test_list + train_list,
+            "Series",
+        )
+        for _series_ndx, series_uid in series_iter:
+            ct, output_ary, _mask_ary, clean_ary = self.segmentCt(series_uid)
+
+            noduleInfo_list += self.clusterSegmentationOutput(
+                series_uid,
+                ct,
+                clean_ary,
+            )
+
+            # if _series_ndx > 10:
+            #     break
+
+
+        cls_dl = self.initClassificationDl(noduleInfo_list)
+
+        series2diagnosis_dict = {}
+        batch_iter = enumerateWithEstimate(
+            cls_dl,
+            "Cls all",
+            start_ndx=cls_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            input_tensor, _, series_list, center_list = batch_tup
+
+            input_devtensor = input_tensor.to(self.device)
+            with torch.no_grad():
+                _logits_devtensor, probability_devtensor = self.cls_model(input_devtensor)
+
+            classifications_list = zip(
+                series_list,
+                center_list,
+                probability_devtensor[:,1].to('cpu'),
+            )
+
+            for cls_tup in classifications_list:
+                series_uid, center_irc, probability_tensor = cls_tup
+                probability_float = probability_tensor.item()
+
+                this_tup = (probability_float, tuple(center_irc))
+                current_tup = series2diagnosis_dict.get(series_uid, this_tup)
+                try:
+                    assert np.all(np.isfinite(tuple(center_irc)))
+                    if this_tup > current_tup:
+                        log.debug([series_uid, this_tup])
+                    series2diagnosis_dict[series_uid] = max(this_tup, current_tup)
+                except:
+                    log.debug([(type(x), x) for x in this_tup] + [(type(x), x) for x in current_tup])
+                    raise
+
+                # self.logResults(
+                #     'Testing' if isTest_bool else 'Training',
+                #     [(series_uid, series2diagnosis_dict[series_uid])],
+                #     malignant_set,
+                # )
+
+        log.info('Training set:')
+        self.logResults('Training', train_list, series2diagnosis_dict, malignant_set)
+
+        log.info('Testing set:')
+        self.logResults('Testing', test_list, series2diagnosis_dict, malignant_set)
+
+    def segmentCt(self, series_uid):
+        with torch.no_grad():
+            ct = getCt(series_uid)
+
+            output_ary = np.zeros_like(ct.ary, dtype=np.float32)
+
+            seg_dl = self.initSegmentationDl(series_uid)
+            for batch_tup in seg_dl:
+                input_tensor = batch_tup[0]
+                ndx_list = batch_tup[6]
+
+                input_devtensor = input_tensor.to(self.device)
+                prediction_devtensor = self.seg_model(input_devtensor)
+
+                for i, sample_ndx in enumerate(ndx_list):
+                    output_ary[sample_ndx] = prediction_devtensor[i].cpu().numpy()
+
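+            # Threshold the per-voxel segmentation probabilities at 0.5, then erode once
+            # to drop isolated speckles and dilate twice to re-grow (and slightly pad)
+            # the surviving candidate regions before connected-component clustering.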
+            mask_ary = output_ary > 0.5
+            clean_ary = morph.binary_erosion(mask_ary, iterations=1)
+            clean_ary = morph.binary_dilation(clean_ary, iterations=2)
+
+        return ct, output_ary, mask_ary, clean_ary
+
+    def clusterSegmentationOutput(self, series_uid, ct, clean_ary):
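+        # Group the cleaned-up segmentation voxels into connected components, one per
+        # candidate nodule, and take each component's center of mass. The +1001 offset
+        # keeps the weights positive, since the CT values were clamped to >= -1000 HU.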
+        noduleLabel_ary, nodule_count = measure.label(clean_ary)
+        centerIrc_list = measure.center_of_mass(
+            ct.ary + 1001,
+            labels=noduleLabel_ary,
+            index=list(range(1, nodule_count+1)),
+        )
+
+        # n = 1298
+        # log.debug([
+        #     (noduleLabel_ary == n).sum(),
+        #     np.where(noduleLabel_ary == n),
+        #
+        #     ct.ary[noduleLabel_ary == n].sum(),
+        #     (ct.ary + 1000)[noduleLabel_ary == n].sum(),
+        # ])
+
+        if nodule_count < 2:
+            centerIrc_list = [centerIrc_list]
+
+        noduleInfo_list = []
+        for i, center_irc in enumerate(centerIrc_list):
+            center_xyz = irc2xyz(
+                center_irc,
+                ct.origin_xyz,
+                ct.vxSize_xyz,
+                ct.direction_tup,
+            )
+            assert np.all(np.isfinite(center_irc)), repr(['irc', center_irc, i, nodule_count])
+            assert np.all(np.isfinite(center_xyz)), repr(['xyz', center_xyz])
+            noduleInfo_tup = \
+                NoduleInfoTuple(False, 0.0, series_uid, center_xyz)
+            noduleInfo_list.append(noduleInfo_tup)
+
+        return noduleInfo_list
+
+    def logResults(self, mode_str, filtered_list, series2diagnosis_dict, malignant_set):
+        count_dict = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
+        for series_uid in filtered_list:
+            probability_float, center_irc = series2diagnosis_dict.get(series_uid, (0.0, None))
+            if center_irc is not None:
+                center_irc = tuple(int(x.item()) for x in center_irc)
+            malignant_bool = series_uid in malignant_set
+            prediction_bool = probability_float > 0.5
+            correct_bool = malignant_bool == prediction_bool
+
+            if malignant_bool and prediction_bool:
+                count_dict['tp'] += 1
+            if not malignant_bool and not prediction_bool:
+                count_dict['tn'] += 1
+            if not malignant_bool and prediction_bool:
+                count_dict['fp'] += 1
+            if malignant_bool and not prediction_bool:
+                count_dict['fn'] += 1
+
+
+            log.info("{} {} Mal:{!r:5} Pred:{!r:5} Correct?:{!r:5} Value:{:.4f} {}".format(
+                mode_str,
+                series_uid,
+                malignant_bool,
+                prediction_bool,
+                correct_bool,
+                probability_float,
+                center_irc,
+            ))
+
+        total_count = sum(count_dict.values())
+        percent_dict = {k: v / (total_count or 1) * 100 for k, v in count_dict.items()}
+
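+        # Standard precision/recall/F1 over the per-series diagnoses; the `or 1`
+        # fallbacks avoid division by zero, yielding 0.0 when a denominator is empty.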
+        precision = percent_dict['p'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fp']) or 1)
+        recall    = percent_dict['r'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fn']) or 1)
+        percent_dict['f1'] = 2 * (precision * recall) / ((precision + recall) or 1)
+
+        log.info(mode_str + " tp:{tp:.1f}%, tn:{tn:.1f}%, fp:{fp:.1f}%, fn:{fn:.1f}%".format(
+            **percent_dict,
+        ))
+        log.info(mode_str + " precision:{p:.3f}, recall:{r:.3f}, F1:{f1:.3f}".format(
+            **percent_dict,
+        ))
+
+
+
+if __name__ == '__main__':
+    sys.exit(LunaDiagnoseApp().main() or 0)

+ 569 - 0
p2ch12/dsets.py

@@ -0,0 +1,569 @@
+import copy
+import csv
+import functools
+import glob
+import math
+import os
+import random
+
+from collections import namedtuple
+
+import SimpleITK as sitk
+
+import numpy as np
+import scipy.ndimage.morphology as morph
+
+import torch
+import torch.cuda
+from torch.utils.data import Dataset
+import torch.nn as nn
+import torch.nn.functional as F
+
+from util.disk import getCache
+from util.util import XyzTuple, xyz2irc
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+raw_cache = getCache('part2ch12_raw')
+
+NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
+MaskTuple = namedtuple('MaskTuple', 'air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, ben_mask, mal_mask')
+
+@functools.lru_cache(1)
+def getNoduleInfoList(requireDataOnDisk_bool=True):
+    # We construct a set with all series_uids that are present on disk.
+    # This will let us use the data, even if we haven't downloaded all of
+    # the subsets yet.
+    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
+    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+
+    diameter_dict = {}
+    with open('data/part2/luna/annotations.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
+            annotationDiameter_mm = float(row[4])
+
+            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
+
+    noduleInfo_list = []
+    with open('data/part2/luna/candidates.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+
+            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+                continue
+
+            isMalignant_bool = bool(int(row[4]))
+            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
+
+            candidateDiameter_mm = 0.0
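+            # A candidate inherits an annotation's diameter only if its center lies
+            # within a quarter of that diameter of the annotation center along every
+            # axis (a cheap bounding-box check, not a true distance); otherwise the
+            # for/else falls through and the diameter stays 0.0.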
+            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
+                for i in range(3):
+                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
+                    if delta_mm > annotationDiameter_mm / 4:
+                        break
+                else:
+                    candidateDiameter_mm = annotationDiameter_mm
+                    break
+
+            noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+
+    noduleInfo_list.sort(reverse=True)
+    return noduleInfo_list
+
+class Ct(object):
+    def __init__(self, series_uid, buildMasks_bool=True):
+        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
+
+        ct_mhd = sitk.ReadImage(mhd_path)
+        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+
+        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
+        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
+        # This gets rid of negative density stuff used to indicate out-of-FOV
+        ct_ary[ct_ary < -1000] = -1000
+
+        # This nukes any weird hotspots and clamps bone down
+        ct_ary[ct_ary > 1000] = 1000
+
+        self.series_uid = series_uid
+        self.ary = ct_ary
+
+        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
+        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
+        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
+
+        noduleInfo_list = getNoduleInfoList()
+        self.benignInfo_list = [ni_tup
+                                for ni_tup in noduleInfo_list
+                                    if not ni_tup.isMalignant_bool
+                                        and ni_tup.series_uid == self.series_uid]
+        self.benign_mask = self.buildAnnotationMask(self.benignInfo_list)[0]
+        self.benign_indexes = sorted(set(self.benign_mask.nonzero()[0]))
+
+        self.malignantInfo_list = [ni_tup
+                                   for ni_tup in noduleInfo_list
+                                        if ni_tup.isMalignant_bool
+                                            and ni_tup.series_uid == self.series_uid]
+        self.malignant_mask = self.buildAnnotationMask(self.malignantInfo_list)[0]
+        self.malignant_indexes = sorted(set(self.malignant_mask.nonzero()[0]))
+
+    def buildAnnotationMask(self, noduleInfo_list, threshold_gcc = -500):
+        boundingBox_ary = np.zeros_like(self.ary, dtype=np.bool)
+
+        for noduleInfo_tup in noduleInfo_list:
+            center_irc = xyz2irc(
+                noduleInfo_tup.center_xyz,
+                self.origin_xyz,
+                self.vxSize_xyz,
+                self.direction_tup,
+            )
+            ci = int(center_irc.index)
+            cr = int(center_irc.row)
+            cc = int(center_irc.col)
+
+            index_radius = 2
+            try:
+                while self.ary[ci + index_radius, cr, cc] > threshold_gcc and \
+                        self.ary[ci - index_radius, cr, cc] > threshold_gcc:
+                    index_radius += 1
+            except IndexError:
+                index_radius -= 1
+
+            row_radius = 2
+            try:
+                while self.ary[ci, cr + row_radius, cc] > threshold_gcc and \
+                        self.ary[ci, cr - row_radius, cc] > threshold_gcc:
+                    row_radius += 1
+            except IndexError:
+                row_radius -= 1
+
+            col_radius = 2
+            try:
+                while self.ary[ci, cr, cc + col_radius] > threshold_gcc and \
+                        self.ary[ci, cr, cc - col_radius] > threshold_gcc:
+                    col_radius += 1
+            except IndexError:
+                col_radius -= 1
+
+            # assert index_radius > 0, repr([noduleInfo_tup.center_xyz, center_irc, self.ary[ci, cr, cc]])
+            # assert row_radius > 0
+            # assert col_radius > 0
+
+
+            slice_tup = (
+                slice(ci - index_radius, ci + index_radius + 1),
+                slice(cr - row_radius, cr + row_radius + 1),
+                slice(cc - col_radius, cc + col_radius + 1),
+            )
+            boundingBox_ary[slice_tup] = True
+
+        thresholded_ary = boundingBox_ary & (self.ary > threshold_gcc)
+        mask_ary = morph.binary_dilation(thresholded_ary, iterations=2)
+
+        return mask_ary, thresholded_ary, boundingBox_ary
+
+    def build2dLungMask(self, mask_ndx, threshold_gcc = -300):
+        dense_mask = self.ary[mask_ndx] > threshold_gcc
+        denoise_mask = morph.binary_closing(dense_mask, iterations=2)
+        tissue_mask = morph.binary_opening(denoise_mask, iterations=10)
+        body_mask = morph.binary_fill_holes(tissue_mask)
+        air_mask = morph.binary_fill_holes(body_mask & ~tissue_mask)
+
+        lung_mask = morph.binary_dilation(air_mask, iterations=2)
+
+        ben_mask = denoise_mask & air_mask
+        ben_mask = morph.binary_dilation(ben_mask, iterations=2)
+        ben_mask &= ~self.malignant_mask[mask_ndx]
+
+        mal_mask = self.malignant_mask[mask_ndx]
+
+        return MaskTuple(
+            air_mask,
+            lung_mask,
+            dense_mask,
+            denoise_mask,
+            tissue_mask,
+            body_mask,
+            ben_mask,
+            mal_mask,
+        )
+
+    def build3dLungMask(self):
+        air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, ben_mask, mal_mask = mask_list = \
+            [np.zeros_like(self.ary, dtype=np.bool) for _ in range(8)]
+
+        for mask_ndx in range(self.ary.shape[0]):
+            for i, mask_ary in enumerate(self.build2dLungMask(mask_ndx)):
+                mask_list[i][mask_ndx] = mask_ary
+
+        return MaskTuple(air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, ben_mask, mal_mask)
+
+
+    def getRawNodule(self, center_xyz, width_irc):
+        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
+
+        slice_list = []
+        for axis, center_val in enumerate(center_irc):
+            try:
+                start_ndx = int(round(center_val - width_irc[axis]/2))
+            except:
+                log.debug([center_val, width_irc, center_xyz, center_irc])
+                raise
+            end_ndx = int(start_ndx + width_irc[axis])
+
+            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+
+            if start_ndx < 0:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                start_ndx = 0
+                end_ndx = int(width_irc[axis])
+
+            if end_ndx > self.ary.shape[axis]:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                end_ndx = self.ary.shape[axis]
+                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
+
+            slice_list.append(slice(start_ndx, end_ndx))
+
+        ct_chunk = self.ary[tuple(slice_list)]
+
+        return ct_chunk, center_irc
+
+ctCache_depth = 5
+@functools.lru_cache(ctCache_depth, typed=True)
+def getCt(series_uid):
+    return Ct(series_uid)
+
+@raw_cache.memoize(typed=True)
+def getCtRawNodule(series_uid, center_xyz, width_irc):
+    ct = getCt(series_uid)
+    ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+    return ct_chunk, center_irc
+
+@raw_cache.memoize(typed=True)
+def getCtSampleSize(series_uid):
+    ct = Ct(series_uid, buildMasks_bool=False)
+    return len(ct.benign_indexes)
+
+def getCtAugmentedNodule(
+        augmentation_dict,
+        series_uid, center_xyz, width_irc,
+        use_cache=True):
+    if use_cache:
+        ct_chunk, center_irc = getCtRawNodule(series_uid, center_xyz, width_irc)
+    else:
+        ct = getCt(series_uid)
+        ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+
+    ct_tensor = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
+
+    transform_tensor = torch.eye(4).to(torch.float64)
+
+    for i in range(3):
+        if 'flip' in augmentation_dict:
+            if random.random() > 0.5:
+                transform_tensor[i,i] *= -1
+
+        if 'offset' in augmentation_dict:
+            offset_float = augmentation_dict['offset']
+            random_float = (random.random() * 2 - 1)
+            transform_tensor[i,3] = offset_float * random_float
+
+        if 'scale' in augmentation_dict:
+            scale_float = augmentation_dict['scale']
+            random_float = (random.random() * 2 - 1)
+            transform_tensor[i,i] *= 1.0 + scale_float * random_float
+
+    if 'rotate' in augmentation_dict:
+        angle_rad = random.random() * math.pi * 2
+        s = math.sin(angle_rad)
+        c = math.cos(angle_rad)
+
+        rotation_tensor = torch.tensor([
+            [c, -s, 0, 0],
+            [s, c, 0, 0],
+            [0, 0, 1, 0],
+            [0, 0, 0, 1],
+        ], dtype=torch.float64)
+
+        transform_tensor @= rotation_tensor
+
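+    # The 4x4 transform above is composed in homogeneous coordinates; only its top
+    # three rows (a 3x4 affine matrix) are passed to affine_grid, which produces a
+    # normalized sampling grid, and grid_sample then applies the flip/offset/scale/
+    # rotate to the nodule chunk in a single interpolation pass.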
+    affine_tensor = F.affine_grid(
+            transform_tensor[:3].unsqueeze(0).to(torch.float32),
+            ct_tensor.size(),
+        )
+
+    augmented_chunk = F.grid_sample(
+            ct_tensor,
+            affine_tensor,
+            padding_mode='border'
+        ).to('cpu')
+
+    if 'noise' in augmentation_dict:
+        noise_tensor = torch.randn_like(augmented_chunk)
+        noise_tensor *= augmentation_dict['noise']
+
+        augmented_chunk += noise_tensor
+
+    return augmented_chunk[0], center_irc
+
+
+class LunaDataset(Dataset):
+    def __init__(self,
+                 test_stride=0,
+                 isTestSet_bool=None,
+                 series_uid=None,
+                 sortby_str='random',
+                 ratio_int=0,
+                 augmentation_dict=None,
+                 noduleInfo_list=None,
+            ):
+        self.ratio_int = ratio_int
+        self.augmentation_dict = augmentation_dict
+
+        if noduleInfo_list:
+            self.noduleInfo_list = copy.copy(noduleInfo_list)
+            self.use_cache = False
+        else:
+            self.noduleInfo_list = copy.copy(getNoduleInfoList())
+            self.use_cache = True
+
+        if series_uid:
+            self.series_list = [series_uid]
+        else:
+            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
+        if isTestSet_bool:
+            assert test_stride > 0, test_stride
+            self.series_list = self.series_list[::test_stride]
+            assert self.series_list
+        elif test_stride > 0:
+            del self.series_list[::test_stride]
+            assert self.series_list
+
+        series_set = set(self.series_list)
+        self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid in series_set]
+
+        if sortby_str == 'random':
+            random.shuffle(self.noduleInfo_list)
+        elif sortby_str == 'series_uid':
+            self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sorting by series_uid, center_xyz)
+        elif sortby_str == 'malignancy_size':
+            pass
+        else:
+            raise Exception("Unknown sort: " + repr(sortby_str))
+
+        self.benign_list = [nt for nt in self.noduleInfo_list if not nt.isMalignant_bool]
+        self.malignant_list = [nt for nt in self.noduleInfo_list if nt.isMalignant_bool]
+
+        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
+            self,
+            len(self.noduleInfo_list),
+            "testing" if isTestSet_bool else "training",
+            len(self.benign_list),
+            len(self.malignant_list),
+            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
+        ))
+
+    def shuffleSamples(self):
+        if self.ratio_int:
+            random.shuffle(self.benign_list)
+            random.shuffle(self.malignant_list)
+
+    def __len__(self):
+        if self.ratio_int:
+            # return 20000
+            return 200000
+        else:
+            return len(self.noduleInfo_list)
+
+    def __getitem__(self, ndx):
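+        # With ratio_int set, samples are drawn in a fixed malignant/benign interleave
+        # instead of straight from noduleInfo_list: every (ratio_int + 1)-th index is
+        # malignant and the rest are benign, wrapping around the shuffled per-class
+        # lists. E.g. ratio_int=1 alternates 1:1; ratio_int=2 gives two benign per malignant.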
+        if self.ratio_int:
+            malignant_ndx = ndx // (self.ratio_int + 1)
+
+            if ndx % (self.ratio_int + 1):
+                benign_ndx = ndx - 1 - malignant_ndx
+                nodule_tup = self.benign_list[benign_ndx % len(self.benign_list)]
+            else:
+                nodule_tup = self.malignant_list[malignant_ndx % len(self.malignant_list)]
+        else:
+            nodule_tup = self.noduleInfo_list[ndx]
+
+        width_irc = (24, 48, 48)
+
+        if self.augmentation_dict:
+            nodule_t, center_irc = getCtAugmentedNodule(
+                self.augmentation_dict,
+                nodule_tup.series_uid,
+                nodule_tup.center_xyz,
+                width_irc,
+                self.use_cache,
+            )
+        elif self.use_cache:
+            nodule_ary, center_irc = getCtRawNodule(
+                nodule_tup.series_uid,
+                nodule_tup.center_xyz,
+                width_irc,
+            )
+            nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
+            nodule_t = nodule_t.unsqueeze(0)
+        else:
+            ct = getCt(nodule_tup.series_uid)
+            nodule_ary, center_irc = ct.getRawNodule(
+                nodule_tup.center_xyz,
+                width_irc,
+            )
+            nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
+            nodule_t = nodule_t.unsqueeze(0)
+
+        malignant_tensor = torch.tensor([
+                not nodule_tup.isMalignant_bool,
+                nodule_tup.isMalignant_bool
+            ],
+            dtype=torch.long,
+        )
+
+        # log.debug([type(center_irc), center_irc])
+
+        return nodule_t, malignant_tensor, nodule_tup.series_uid, torch.tensor(center_irc)
+
+
+
+
+class Luna2dSegmentationDataset(Dataset):
+    def __init__(self,
+                 test_stride=0,
+                 isTestSet_bool=None,
+                 series_uid=None,
+                 contextSlices_count=2,
+                 augmentation_dict=None,
+                 fullCt_bool=False,
+            ):
+        self.contextSlices_count = contextSlices_count
+        self.augmentation_dict = augmentation_dict
+
+        if series_uid:
+            self.series_list = [series_uid]
+        else:
+            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
+        if isTestSet_bool:
+            assert test_stride > 0, test_stride
+            self.series_list = self.series_list[::test_stride]
+            assert self.series_list
+        elif test_stride > 0:
+            del self.series_list[::test_stride]
+            assert self.series_list
+
+        self.sample_list = []
+        for series_uid in self.series_list:
+            if fullCt_bool:
+                self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCt(series_uid).ary.shape[0])])
+            else:
+                self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCtSampleSize(series_uid))])
+
+        log.info("{!r}: {} {} series, {} slices".format(
+            self,
+            len(self.series_list),
+            {None: 'general', True: 'testing', False: 'training'}[isTestSet_bool],
+            len(self.sample_list),
+        ))
+
+    def __len__(self):
+        return len(self.sample_list) #// 100
+
+    def __getitem__(self, ndx):
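+        # An int index looks up (series_uid, slice) from sample_list; a tuple index lets
+        # the training subclass pick slices directly. The input stacks contextSlices_count
+        # neighboring slices on each side of the target (repeating edge slices at the
+        # volume boundary) plus one extra channel for the 2D body mask, so
+        # contextSlices_count=3 matches the in_channels=8 used in diagnose.py.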
+        if isinstance(ndx, int):
+            series_uid, ct_ndx = self.sample_list[ndx % len(self.sample_list)]
+            ct = getCt(series_uid)
+            useAugmentation_bool = False
+        else:
+            series_uid, ct_ndx, useAugmentation_bool = ndx
+            ct = getCt(series_uid)
+
+        ct_tensor = torch.zeros((self.contextSlices_count * 2 + 1 + 1, 512, 512))
+
+        start_ndx = ct_ndx - self.contextSlices_count
+        end_ndx = ct_ndx + self.contextSlices_count + 1
+        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
+            context_ndx = max(context_ndx, 0)
+            context_ndx = min(context_ndx, ct.ary.shape[0] - 1)
+
+            ct_tensor[i] = torch.from_numpy(ct.ary[context_ndx].astype(np.float32))
+        ct_tensor /= 1000
+
+        mask_tup = ct.build2dLungMask(ct_ndx)
+
+        ct_tensor[-1] = torch.from_numpy(mask_tup.body_mask.astype(np.float32))
+
+        nodule_tensor = torch.from_numpy(
+            (mask_tup.mal_mask | mask_tup.ben_mask).astype(np.float32)
+        ).unsqueeze(0)
+        ben_tensor = torch.from_numpy(mask_tup.ben_mask.astype(np.float32))
+        mal_tensor = torch.from_numpy(mask_tup.mal_mask.astype(np.float32))
+        label_int = mal_tensor.max() + ben_tensor.max() * 2
+
+        if self.augmentation_dict and useAugmentation_bool:
+            if 'rotate' in self.augmentation_dict:
+                if random.random() > 0.5:
+                    ct_tensor = ct_tensor.rot90(1, [1, 2])
+                    nodule_tensor = nodule_tensor.rot90(1, [1, 2])
+
+            if 'flip' in self.augmentation_dict:
+                dims = [d+1 for d in range(2) if random.random() > 0.5]
+
+                if dims:
+                    ct_tensor = ct_tensor.flip(dims)
+                    nodule_tensor = nodule_tensor.flip(dims)
+
+            if 'noise' in self.augmentation_dict:
+                noise_tensor = torch.randn_like(ct_tensor)
+                noise_tensor *= self.augmentation_dict['noise']
+
+                ct_tensor += noise_tensor
+        return ct_tensor, nodule_tensor, label_int, ben_tensor, mal_tensor, ct.series_uid, ct_ndx
+
+
+class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
+    def __init__(self, *args, batch_size=80, **kwargs):
+        self.needsShuffle_bool = True
+        self.batch_size = batch_size
+        # self.rotate_frac = 0.5 * len(self.series_list) / len(self)
+        super().__init__(*args, **kwargs)
+
+    def __len__(self):
+        return 50000
+
+    def __getitem__(self, ndx):
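+        # Integer indices cycle through a small window of series: rotating series_list
+        # once per batch_size samples and indexing it modulo ctCache_depth keeps only a
+        # few CTs hot in the lru_cache. Slices are drawn roughly one third from
+        # malignant-containing slices, one third from benign-containing slices, and one
+        # third from anywhere in the CT, with augmentation enabled.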
+        if self.needsShuffle_bool:
+            random.shuffle(self.series_list)
+            self.needsShuffle_bool = False
+
+        if isinstance(ndx, int):
+            if ndx % self.batch_size == 0:
+                self.series_list.append(self.series_list.pop(0))
+
+            series_uid = self.series_list[ndx % ctCache_depth]
+            ct = getCt(series_uid)
+
+            if ndx % 3 == 0:
+                ct_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
+            elif ndx % 3 == 1:
+                ct_ndx = random.choice(ct.benign_indexes)
+            elif ndx % 3 == 2:
+                ct_ndx = random.choice(list(range(ct.ary.shape[0])))
+
+            useAugmentation_bool = True
+        else:
+            series_uid, ct_ndx, useAugmentation_bool = ndx
+
+        return super().__getitem__((series_uid, ct_ndx, useAugmentation_bool))

+ 41 - 0
p2ch12/model.py

@@ -0,0 +1,41 @@
+import math
+
+from torch import nn as nn
+
+from util.logconf import logging
+from util.unet import UNet
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+class UNetWrapper(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
+        self.unet = UNet(**kwargs)
+        self.final = nn.Sigmoid()
+
+        for m in self.modules():
+            if type(m) in {
+                nn.Conv2d,
+                nn.Conv3d,
+                nn.ConvTranspose2d,
+                nn.ConvTranspose3d,
+                nn.Linear,
+            }:
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='leaky_relu', a=0)
+                if m.bias is not None:
+                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
+
+
+    def forward(self, input):
+        bn_output = self.batchnorm(input)
+        un_output = self.unet(bn_output)
+        fn_output = self.final(un_output)
+
+        return fn_output

+ 109 - 0
p2ch12/model_cls.py

@@ -0,0 +1,109 @@
+import math
+
+import torch.nn as nn
+
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+
+class LunaModel(nn.Module):
+    def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
+        super().__init__()
+
+        self.input_batchnorm = nn.BatchNorm3d(1)
+
+        layer_list = []
+        for layer_ndx in range(layer_count):
+            layer_list += [
+                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
+                nn.ReLU(inplace=True),
+                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
+                nn.ReLU(inplace=True),
+                nn.MaxPool3d(2, 2),
+            ]
+
+            in_channels = conv_channels
+            conv_channels *= 2
+
+        self.convAndPool_seq = nn.Sequential(*layer_list)
+        self.fullyConnected_layer = nn.Linear(576, 2)
+        self.final = nn.Softmax(dim=1)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        # see also https://github.com/pytorch/pytorch/issues/18182
+        for m in self.modules():
+            if type(m) in {
+                nn.Conv2d,
+                nn.Conv3d,
+                nn.ConvTranspose2d,
+                nn.ConvTranspose3d,
+                nn.Linear,
+            }:
+                # log.debug(m)
+                # nn.init.kaiming_normal_(m.weight.data, mode='fan_out', a=0)
+                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
+
+    def forward(self, input_batch):
+        bn_output = self.input_batchnorm(input_batch)
+        conv_output = self.convAndPool_seq(bn_output)
+        conv_flat = conv_output.view(conv_output.size(0), -1)
+        classifier_output = self.fullyConnected_layer(conv_flat)
+
+        return classifier_output, self.final(classifier_output)
+
+
+class AlternateLunaModel(nn.Module):
+    def __init__(self, layer_count=4, in_channels=1, conv_channels=64):
+        super().__init__()
+
+        layer_list = [nn.BatchNorm3d(1)]
+        for layer_ndx in range(layer_count):
+            layer_list += [
+                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
+                nn.ReLU(inplace=True),
+                nn.Conv3d(conv_channels, conv_channels // 2, kernel_size=3, padding=1, bias=True),
+                nn.ReLU(inplace=True),
+                nn.MaxPool3d(2, 2),
+            ]
+
+            conv_channels //= 2
+            in_channels = conv_channels
+
+        self.convAndPool_seq = nn.Sequential(*layer_list)
+        self.fullyConnected_layer = nn.Linear(36, 2)
+        self.final = nn.Softmax(dim=1)
+
+        # see also https://github.com/pytorch/pytorch/issues/18182
+        for m in self.modules():
+            if type(m) in {
+                nn.Conv2d,
+                nn.Conv3d,
+                nn.ConvTranspose2d,
+                nn.ConvTranspose3d,
+                nn.Linear,
+            }:
+                # log.debug(m)
+                # nn.init.kaiming_normal_(m.weight.data, mode='fan_out', a=0)
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
+
+    def forward(self, input_batch):
+        conv_output = self.convAndPool_seq(input_batch)
+        conv_flat = conv_output.view(conv_output.size(0), -1)
+        classifier_output = self.fullyConnected_layer(conv_flat)
+
+        return classifier_output, self.final(classifier_output)
+

+ 46 - 0
p2ch12/model_seg.py

@@ -0,0 +1,46 @@
+import math
+
+from torch import nn as nn
+
+from util.logconf import logging
+from util.unet import UNet
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+class UNetWrapper(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
+        self.unet = UNet(**kwargs)
+        self.final = nn.Sigmoid()
+
+        self._init_weights()
+
+    def _init_weights(self):
+        init_set = {
+            nn.Conv2d,
+            nn.Conv3d,
+            nn.ConvTranspose2d,
+            nn.ConvTranspose3d,
+            nn.Linear,
+        }
+        for m in self.modules():
+            if type(m) in init_set:
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu', a=0)
+                if m.bias is not None:
+                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
+
+
+    def forward(self, input_batch):
+        bn_output = self.input_batchnorm(input_batch)
+        un_output = self.unet(bn_output)
+        fn_output = self.final(un_output)
+
+        return fn_output
+

+ 72 - 0
p2ch12/prepcache.py

@@ -0,0 +1,72 @@
+import argparse
+import sys
+
+import numpy as np
+
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset, getCtSampleSize
+from util.logconf import logging
+# from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+
+class LunaPrepCacheApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=1024,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        # parser.add_argument('--scaled',
+        #     help="Scale the CT chunks to square voxels.",
+        #     default=False,
+        #     action='store_true',
+        # )
+
+        self.cli_args = parser.parse_args(sys_argv)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.prep_dl = DataLoader(
+            LunaDataset(
+                sortby_str='series_uid',
+            ),
+            batch_size=self.cli_args.batch_size,
+            num_workers=self.cli_args.num_workers,
+        )
+
+        batch_iter = enumerateWithEstimate(
+            self.prep_dl,
+            "Stuffing cache",
+            start_ndx=self.prep_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            _nodule_tensor, _malignant_tensor, series_list, _center_list = batch_tup
+            for series_uid in sorted(set(series_list)):
+                getCtSampleSize(series_uid)
+            # input_tensor, label_tensor, _series_list, _start_list = batch_tup
+
+
+
+if __name__ == '__main__':
+    sys.exit(LunaPrepCacheApp().main() or 0)

+ 92 - 0
p2ch12/screencts.py

@@ -0,0 +1,92 @@
+import argparse
+import sys
+
+import numpy as np
+
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import Dataset, DataLoader
+
+from util.util import enumerateWithEstimate, prhist
+from .dsets import getNoduleInfoList, getCt
+from util.logconf import logging
+# from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+
+class LunaScreenCtDataset(Dataset):
+    def __init__(self):
+        self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
+    def __len__(self):
+        return len(self.series_list)
+
+    def __getitem__(self, ndx):
+        series_uid = self.series_list[ndx]
+        ct = getCt(series_uid)
+        mid_ndx = ct.ary.shape[0] // 2
+
+        mask_tup = ct.build2dLungMask(mid_ndx)
+
+        return series_uid, float(mask_tup.dense_mask.sum() / mask_tup.denoise_mask.sum())
+
+
+class LunaScreenCtApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=4,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        # parser.add_argument('--scaled',
+        #     help="Scale the CT chunks to square voxels.",
+        #     default=False,
+        #     action='store_true',
+        # )
+
+        self.cli_args = parser.parse_args(sys_argv)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.prep_dl = DataLoader(
+            LunaScreenCtDataset(),
+            batch_size=self.cli_args.batch_size,
+            num_workers=self.cli_args.num_workers,
+        )
+
+        series2ratio_dict = {}
+
+        batch_iter = enumerateWithEstimate(
+            self.prep_dl,
+            "Screening CTs",
+            start_ndx=self.prep_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            series_list, ratio_list = batch_tup
+            for series_uid, ratio_float in zip(series_list, ratio_list):
+                series2ratio_dict[series_uid] = ratio_float
+            # break
+
+        prhist(list(series2ratio_dict.values()))
+
+
+
+
+if __name__ == '__main__':
+    sys.exit(LunaScreenCtApp().main() or 0)

+ 454 - 0
p2ch12/train_cls.py

@@ -0,0 +1,454 @@
+import argparse
+import datetime
+import os
+import sys
+
+import numpy as np
+
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+from torch.optim import SGD, Adam
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset
+from .model_cls import LunaModel
+
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+METRICS_LABEL_NDX=0
+METRICS_PRED_NDX=1
+METRICS_LOSS_NDX=2
+METRICS_SIZE = 3
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=32,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+        parser.add_argument('--balanced',
+            help="Balance the training data to half benign, half malignant.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augmented',
+            help="Augment the training data.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-flip',
+            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-offset',
+            help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-scale',
+            help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-rotate',
+            help="Augment the training data by randomly rotating the data around the head-foot axis.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-noise',
+            help="Augment the training data by randomly adding noise to the data.",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch12',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='none',
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
+
+        self.totalTrainingSamples_count = 0
+        self.trn_writer = None
+        self.tst_writer = None
+
+        self.augmentation_dict = {}
+        if self.cli_args.augmented or self.cli_args.augment_flip:
+            self.augmentation_dict['flip'] = True
+        if self.cli_args.augmented or self.cli_args.augment_offset:
+            self.augmentation_dict['offset'] = 0.1
+        if self.cli_args.augmented or self.cli_args.augment_scale:
+            self.augmentation_dict['scale'] = 0.2
+        if self.cli_args.augmented or self.cli_args.augment_rotate:
+            self.augmentation_dict['rotate'] = True
+        if self.cli_args.augmented or self.cli_args.augment_noise:
+            self.augmentation_dict['noise'] = 25.0
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        self.model = self.initModel()
+        self.optimizer = self.initOptimizer()
+
+
+    def initModel(self):
+        model = LunaModel()
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                model = nn.DataParallel(model)
+            model = model.to(self.device)
+        return model
+
+    def initOptimizer(self):
+        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
+#         return Adam(self.model.parameters())
+
+    def initTrainDl(self):
+        train_ds = LunaDataset(
+            test_stride=10,
+            isTestSet_bool=False,
+            ratio_int=int(self.cli_args.balanced),
+            augmentation_dict=self.augmentation_dict,
+        )
+
+        train_dl = DataLoader(
+            train_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return train_dl
+
+    def initTestDl(self):
+        test_ds = LunaDataset(
+            test_stride=10,
+            isTestSet_bool=True,
+        )
+
+        test_dl = DataLoader(
+            test_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return test_dl
+
+    def initTensorboardWriters(self):
+        if self.trn_writer is None:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_cls_' + self.cli_args.comment)
+            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_cls_' + self.cli_args.comment)
+# end::tb_writer[]
+
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+        test_dl = self.initTestDl()
+
+        best_score = 0.0
+
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(test_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            trnMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_tensor)
+
+            tstMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
+            score = self.logMetrics(epoch_ndx, 'tst', tstMetrics_tensor)
+            best_score = max(score, best_score)
+
+            self.saveModel('cls', epoch_ndx, score == best_score)
+
+        if hasattr(self, 'trn_writer'):
+            self.trn_writer.close()
+            self.tst_writer.close()
+
+
+    def doTraining(self, epoch_ndx, train_dl):
+        self.model.train()
+        train_dl.dataset.shuffleSamples()
+        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            loss_var = self.computeBatchLoss(
+                batch_ndx,
+                batch_tup,
+                train_dl.batch_size,
+                trainingMetrics_devtensor
+            )
+
+            loss_var.backward()
+            self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
+
+        return trainingMetrics_devtensor.to('cpu')
+
+
+    def doTesting(self, epoch_ndx, test_dl):
+        with torch.no_grad():
+            self.model.eval()
+            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
+            batch_iter = enumerateWithEstimate(
+                test_dl,
+                "E{} Testing ".format(epoch_ndx),
+                start_ndx=test_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
+
+        return testingMetrics_devtensor.to('cpu')
+
+
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
+        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+
+        input_devtensor = input_tensor.to(self.device, non_blocking=True)
+        label_devtensor = label_tensor.to(self.device, non_blocking=True)
+
+        logits_devtensor, probability_devtensor = self.model(input_devtensor)
+
+        # log.debug(['input', input_devtensor.min().item(), input_devtensor.max().item()])
+        # log.debug(['label', label_devtensor.min().item(), label_devtensor.max().item()])
+        # log.debug(['logits', logits_devtensor.min().item(), logits_devtensor.max().item()])
+        # log.debug(['probability', probability_devtensor.min().item(), probability_devtensor.max().item()])
+
+        loss_func = nn.CrossEntropyLoss(reduction='none')
+        loss_devtensor = loss_func(logits_devtensor, label_devtensor[:,1])
+
+        # log.debug(['loss', loss_devtensor.min().item(), loss_devtensor.max().item()])
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
+
+        metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_devtensor[:,1]
+        metrics_devtensor[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_devtensor[:,1]
+        metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor
+
+        return loss_devtensor.mean()
+
+
+    def logMetrics(
+            self,
+            epoch_ndx,
+            mode_str,
+            metrics_tensor,
+    ):
+        self.initTensorboardWriters()
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+        metrics_ary = metrics_tensor.cpu().detach().numpy()
+#         assert np.isfinite(metrics_ary).all()
+
+        benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= 0.5
+        benPred_mask = metrics_ary[METRICS_PRED_NDX] <= 0.5
+
+        malLabel_mask = ~benLabel_mask
+        malPred_mask = ~benPred_mask
+
+        benLabel_count = benLabel_mask.sum()
+        malLabel_count = malLabel_mask.sum()
+
+        trueNeg_count = benCorrect_count = (benLabel_mask & benPred_mask).sum()
+        truePos_count = malCorrect_count = (malLabel_mask & malPred_mask).sum()
+
+        falsePos_count = benLabel_count - benCorrect_count
+        falseNeg_count = malLabel_count - malCorrect_count
+
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
+
+        metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
+        metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
+        metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
+
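+        # Epoch-level precision and recall at a 0.5 probability threshold; f1_score
+        # below is their harmonic mean.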
+        precision = metrics_dict['pr/precision'] = truePos_count / (truePos_count + falsePos_count)
+        recall    = metrics_dict['pr/recall']    = truePos_count / (truePos_count + falseNeg_count)
+
+        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / (precision + recall)
+
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/all:.4f} loss, "
+                 + "{correct/all:-5.1f}% correct, "
+                 + "{pr/precision:.4f} precision, "
+                 + "{pr/recall:.4f} recall, "
+                 + "{pr/f1_score:.4f} f1 score"
+            ).format(
+                epoch_ndx,
+                mode_str,
+                **metrics_dict,
+            )
+        )
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/ben:.4f} loss, "
+                 + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_ben',
+                benCorrect_count=benCorrect_count,
+                benLabel_count=benLabel_count,
+                **metrics_dict,
+            )
+        )
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/mal:.4f} loss, "
+                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_mal',
+                malCorrect_count=malCorrect_count,
+                malLabel_count=malLabel_count,
+                **metrics_dict,
+            )
+        )
+        writer = getattr(self, mode_str + '_writer')
+
+        for key, value in metrics_dict.items():
+            writer.add_scalar(key, value, self.totalTrainingSamples_count)
+
+        writer.add_pr_curve(
+            'pr',
+            metrics_ary[METRICS_LABEL_NDX],
+            metrics_ary[METRICS_PRED_NDX],
+            self.totalTrainingSamples_count,
+        )
+
+        bins = [x/50.0 for x in range(51)]
+
+        benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
+        malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
+
+        if benHist_mask.any():
+            writer.add_histogram(
+                'is_ben',
+                metrics_ary[METRICS_PRED_NDX, benHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+        if malHist_mask.any():
+            writer.add_histogram(
+                'is_mal',
+                metrics_ary[METRICS_PRED_NDX, malHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+
+        score = 1 \
+            + metrics_dict['pr/f1_score'] \
+            - metrics_dict['loss/mal'] * 0.01 \
+            - metrics_dict['loss/all'] * 0.0001
+
+        return score
+
+    def saveModel(self, type_str, epoch_ndx, isBest=False):
+        file_path = os.path.join(
+            'data-unversioned',
+            'part2',
+            'models',
+            self.cli_args.tb_prefix,
+            '{}_{}_{}.{}.state'.format(
+                type_str,
+                self.time_str,
+                self.cli_args.comment,
+                self.totalTrainingSamples_count,
+            )
+        )
+
+        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
+
+        model = self.model
+        if hasattr(model, 'module'):
+            model = model.module
+
+        state = {
+            'model_state': model.state_dict(),
+            'model_name': type(model).__name__,
+            'optimizer_state' : self.optimizer.state_dict(),
+            'optimizer_name': type(self.optimizer).__name__,
+            'epoch': epoch_ndx,
+            'totalTrainingSamples_count': self.totalTrainingSamples_count,
+            # 'resumed_from': self.cli_args.resume,
+        }
+        torch.save(state, file_path)
+
+        log.debug("Saved model params to {}".format(file_path))
+
+        if isBest:
+            file_path = os.path.join(
+                'data-unversioned',
+                'part2',
+                'models',
+                self.cli_args.tb_prefix,
+                '{}_{}_{}.{}.state'.format(
+                    type_str,
+                    self.time_str,
+                    self.cli_args.comment,
+                    'best',
+                )
+            )
+            torch.save(state, file_path)
+
+            log.debug("Saved model params to {}".format(file_path))
+
+if __name__ == '__main__':
+    sys.exit(LunaTrainingApp().main() or 0)
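
A minimal worked sketch (with made-up counts) of the precision/recall/F1 math that logMetrics above derives from the thresholded predictions; the 0.5 threshold and the 2*(p*r)/(p+r) formula mirror the code:

# Hypothetical counts; in logMetrics they come from the label/prediction masks.
truePos_count, falsePos_count, falseNeg_count = 80, 20, 10

precision = truePos_count / (truePos_count + falsePos_count)   # 0.8
recall = truePos_count / (truePos_count + falseNeg_count)      # ~0.889
f1_score = 2 * (precision * recall) / (precision + recall)     # ~0.842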

+ 538 - 0
p2ch12/train_seg.py

@@ -0,0 +1,538 @@
+import argparse
+import datetime
+import os
+import socket
+import sys
+
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.optim import SGD, Adam
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
+from util.logconf import logging
+from util.util import xyz2irc
+from .model_seg import UNetWrapper
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+METRICS_LABEL_NDX = 0
+METRICS_LOSS_NDX = 1
+METRICS_MAL_LOSS_NDX = 2
+# METRICS_ALL_LOSS_NDX = 3
+
+METRICS_MTP_NDX = 4
+METRICS_MFN_NDX = 5
+METRICS_MFP_NDX = 6
+METRICS_ATP_NDX = 7
+METRICS_AFN_NDX = 8
+METRICS_AFP_NDX = 9
+
+METRICS_SIZE = 10
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=16,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+
+        parser.add_argument('--augmented',
+            help="Augment the training data.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-flip',
+            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
+            action='store_true',
+            default=False,
+        )
+        # parser.add_argument('--augment-offset',
+        #     help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
+        #     action='store_true',
+        #     default=False,
+        # )
+        # parser.add_argument('--augment-scale',
+        #     help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
+        #     action='store_true',
+        #     default=False,
+        # )
+        parser.add_argument('--augment-rotate',
+            help="Augment the training data by randomly rotating the data around the head-foot axis.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-noise',
+            help="Augment the training data by randomly adding noise to the data.",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch12',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='none',
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
+        self.totalTrainingSamples_count = 0
+        self.trn_writer = None
+        self.tst_writer = None
+
+        augmentation_dict = {}
+        if self.cli_args.augmented or self.cli_args.augment_flip:
+            augmentation_dict['flip'] = True
+        if self.cli_args.augmented or self.cli_args.augment_rotate:
+            augmentation_dict['rotate'] = True
+        if self.cli_args.augmented or self.cli_args.augment_noise:
+            augmentation_dict['noise'] = 25.0
+        self.augmentation_dict = augmentation_dict
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        self.model = self.initModel()
+        self.optimizer = self.initOptimizer()
+
+
+    def initModel(self):
+        model = UNetWrapper(in_channels=8, n_classes=1, depth=4, wf=3, padding=True, batch_norm=True, up_mode='upconv')
+
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                model = nn.DataParallel(model)
+            model = model.to(self.device)
+        return model
+
+    def initOptimizer(self):
+        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
+        # return Adam(self.model.parameters())
+
+
+    def initTrainDl(self):
+        train_ds = TrainingLuna2dSegmentationDataset(
+            test_stride=10,
+            isTestSet_bool=False,
+            contextSlices_count=3,
+            augmentation_dict=self.augmentation_dict,
+        )
+
+        train_dl = DataLoader(
+            train_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return train_dl
+
+    def initTestDl(self):
+        test_ds = Luna2dSegmentationDataset(
+            test_stride=10,
+            isTestSet_bool=True,
+            contextSlices_count=3,
+        )
+
+        test_dl = DataLoader(
+            test_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return test_dl
+
+    def initTensorboardWriters(self):
+        if self.trn_writer is None:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
+            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_seg_' + self.cli_args.comment)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+        test_dl = self.initTestDl()
+
+        # self.logModelMetrics(self.model)
+
+        best_score = 0.0
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(test_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trainingMetrics_tensor)
+            self.logImages(epoch_ndx, 'trn', train_dl)
+            self.logImages(epoch_ndx, 'tst', test_dl)
+            # self.logModelMetrics(self.model)
+
+            testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
+            score = self.logMetrics(epoch_ndx, 'tst', testingMetrics_tensor)
+            best_score = max(score, best_score)
+
+            self.saveModel('seg', epoch_ndx, score == best_score)
+
+        if hasattr(self, 'trn_writer'):
+            self.trn_writer.close()
+            self.tst_writer.close()
+
+    def doTraining(self, epoch_ndx, train_dl):
+        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        self.model.train()
+
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_devtensor)
+            loss_var.backward()
+
+            self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
+
+        return trainingMetrics_devtensor.to('cpu')
+
+    def doTesting(self, epoch_ndx, test_dl):
+        with torch.no_grad():
+            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
+            self.model.eval()
+
+            batch_iter = enumerateWithEstimate(
+                test_dl,
+                "E{} Testing ".format(epoch_ndx),
+                start_ndx=test_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
+
+        return testingMetrics_devtensor.to('cpu')
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
+        input_tensor, label_tensor, label_list, ben_tensor, mal_tensor, _series_list, _start_list = batch_tup
+
+        input_devtensor = input_tensor.to(self.device, non_blocking=True)
+        label_devtensor = label_tensor.to(self.device, non_blocking=True)
+        mal_devtensor = mal_tensor.to(self.device, non_blocking=True)
+        ben_devtensor = ben_tensor.to(self.device, non_blocking=True)
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
+        intersectionSum = lambda a, b: (a * b).view(a.size(0), -1).sum(dim=1)
+
+        prediction_devtensor = self.model(input_devtensor)
+        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
+
+        with torch.no_grad():
+            predictionBool_devtensor = (prediction_devtensor > 0.5).to(torch.float32)
+
+            metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
+            metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_devtensor
+
+            # benPred_devtensor = predictionBool_devtensor * (1 - mal_devtensor)
+            tp = intersectionSum(    label_devtensor,     predictionBool_devtensor)
+            fn = intersectionSum(    label_devtensor, 1 - predictionBool_devtensor)
+            fp = intersectionSum(1 - label_devtensor,     predictionBool_devtensor)
+            # ls = self.diceLoss(label_devtensor, benPred_devtensor)
+
+            metrics_devtensor[METRICS_ATP_NDX, start_ndx:end_ndx] = tp
+            metrics_devtensor[METRICS_AFN_NDX, start_ndx:end_ndx] = fn
+            metrics_devtensor[METRICS_AFP_NDX, start_ndx:end_ndx] = fp
+            # metrics_devtensor[METRICS_ALL_LOSS_NDX, start_ndx:end_ndx] = ls
+
+            del tp, fn, fp
+
+            malPred_devtensor = predictionBool_devtensor * (1 - ben_devtensor)
+            tp = intersectionSum(    mal_devtensor,       malPred_devtensor)
+            fn = intersectionSum(    mal_devtensor,   1 - malPred_devtensor)
+            fp = intersectionSum(1 - label_devtensor,     malPred_devtensor)
+            ls = self.diceLoss(mal_devtensor, malPred_devtensor)
+
+            metrics_devtensor[METRICS_MTP_NDX, start_ndx:end_ndx] = tp
+            metrics_devtensor[METRICS_MFN_NDX, start_ndx:end_ndx] = fn
+            # metrics_devtensor[METRICS_MFP_NDX, start_ndx:end_ndx] = fp
+            metrics_devtensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = ls
+
+            del malPred_devtensor, tp, fn, fp, ls
+
+        return diceLoss_devtensor.mean()
+
+    # def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01, p=False):
+    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=1024, p=False):
+        sum_dim1 = lambda t: t.view(t.size(0), -1).sum(dim=1)
+
+        diceLabel_devtensor = sum_dim1(label_devtensor)
+        dicePrediction_devtensor = sum_dim1(prediction_devtensor)
+        diceCorrect_devtensor = sum_dim1(prediction_devtensor * label_devtensor)
+
+        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
+        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)
+
+        if p:
+            log.debug([])
+            log.debug(['diceCorrect_devtensor   ', diceCorrect_devtensor[0].item()])
+            log.debug(['dicePrediction_devtensor', dicePrediction_devtensor[0].item()])
+            log.debug(['diceLabel_devtensor     ', diceLabel_devtensor[0].item()])
+            log.debug(['2*diceCorrect_devtensor ', 2 * diceCorrect_devtensor[0].item()])
+            log.debug(['Prediction + Label      ', dicePrediction_devtensor[0].item() + diceLabel_devtensor[0].item()])
+            log.debug(['diceLoss_devtensor      ', diceLoss_devtensor[0].item()])
+
+        return diceLoss_devtensor
+
+
+    def logImages(self, epoch_ndx, mode_str, dl):
+        for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
+            ct = getCt(series_uid)
+
+            for slice_ndx in range(0, ct.ary.shape[0], ct.ary.shape[0] // 5):
+                sample_tup = dl.dataset[(series_uid, slice_ndx, False)]
+
+                ct_tensor, nodule_tensor, label_int, ben_tensor, mal_tensor, series_uid, ct_ndx = sample_tup
+
+                ct_tensor[:-1,:,:] += 1000
+                ct_tensor[:-1,:,:] /= 2000
+
+                input_devtensor = ct_tensor.to(self.device)
+                label_devtensor = nodule_tensor.to(self.device)
+
+                prediction_devtensor = self.model(input_devtensor.unsqueeze(0))[0]
+                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()
+                label_ary = nodule_tensor.numpy()
+                ben_ary = ben_tensor.numpy()
+                mal_ary = mal_tensor.numpy()
+
+                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
+                image_ary[:,:,:] = (ct_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1)))
+                image_ary[:,:,0] += prediction_ary[0] * (1 - label_ary[0])  # Red
+                image_ary[:,:,1] += prediction_ary[0] * mal_ary  # Green
+                image_ary[:,:,2] += prediction_ary[0] * ben_ary  # Blue
+
+                writer = getattr(self, mode_str + '_writer')
+                image_ary *= 0.5
+                image_ary[image_ary < 0] = 0
+                image_ary[image_ary > 1] = 1
+                writer.add_image('{}/{}_prediction_{}'.format(mode_str, i, slice_ndx), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+
+                # self.diceLoss(label_devtensor, prediction_devtensor, p=True)
+
+                if epoch_ndx == 1:
+                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
+                    image_ary[:,:,:] = (ct_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1)))
+                    image_ary[:,:,0] += (1 - label_ary[0]) * ct_tensor[-1].numpy() # Red
+                    image_ary[:,:,1] += mal_ary  # Green
+                    image_ary[:,:,2] += ben_ary  # Blue
+
+                    writer = getattr(self, mode_str + '_writer')
+                    image_ary *= 0.5
+                    image_ary[image_ary < 0] = 0
+                    image_ary[image_ary > 1] = 1
+                    writer.add_image('{}/{}_label_{}'.format(mode_str, i, slice_ndx), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+
+
+    def logMetrics(self,
+        epoch_ndx,
+        mode_str,
+        metrics_tensor,
+    ):
+        self.initTensorboardWriters()
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+        metrics_ary = metrics_tensor.cpu().detach().numpy()
+        sum_ary = metrics_ary.sum(axis=1)
+        assert np.isfinite(metrics_ary).all()
+
+        malLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 1) | (metrics_ary[METRICS_LABEL_NDX] == 3)
+
+        # allLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 2) | (metrics_ary[METRICS_LABEL_NDX] == 3)
+
+        allLabel_count = sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFN_NDX]
+        malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]
+
+        allCorrect_count = sum_ary[METRICS_ATP_NDX]
+        malCorrect_count = sum_ary[METRICS_MTP_NDX]
+#
+#             falsePos_count = allLabel_count - allCorrect_count
+#             falseNeg_count = malLabel_count - malCorrect_count
+
+
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/mal'] = np.nan_to_num(metrics_ary[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
+        # metrics_dict['loss/all'] = metrics_ary[METRICS_ALL_LOSS_NDX, allLabel_mask].mean()
+
+        metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
+        metrics_dict['correct/all'] = sum_ary[METRICS_ATP_NDX] / (sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFN_NDX]) * 100
+
+        precision = metrics_dict['pr/precision'] = sum_ary[METRICS_ATP_NDX] / ((sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFP_NDX]) or 1)
+        recall    = metrics_dict['pr/recall']    = sum_ary[METRICS_ATP_NDX] / ((sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFN_NDX]) or 1)
+
+        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)
+
+        log.info(("E{} {:8} "
+                 + "{loss/all:.4f} loss, "
+                 + "{pr/precision:.4f} precision, "
+                 + "{pr/recall:.4f} recall, "
+                 + "{pr/f1_score:.4f} f1 score"
+                  ).format(
+            epoch_ndx,
+            mode_str,
+            **metrics_dict,
+        ))
+        log.info(("E{} {:8} "
+                 + "{loss/all:.4f} loss, "
+                 + "{correct/all:-5.1f}% correct ({allCorrect_count:} of {allLabel_count:})"
+        ).format(
+            epoch_ndx,
+            mode_str + '_all',
+            allCorrect_count=allCorrect_count,
+            allLabel_count=allLabel_count,
+            **metrics_dict,
+        ))
+
+        log.info(("E{} {:8} "
+                 + "{loss/mal:.4f} loss, "
+                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+        ).format(
+            epoch_ndx,
+            mode_str + '_mal',
+            malCorrect_count=malCorrect_count,
+            malLabel_count=malLabel_count,
+            **metrics_dict,
+        ))
+        writer = getattr(self, mode_str + '_writer')
+
+        prefix_str = 'seg_'
+
+        for key, value in metrics_dict.items():
+            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)
+
+        score = 1 \
+            + metrics_dict['pr/f1_score'] \
+            - metrics_dict['pr/recall'] * 0.01 \
+            - metrics_dict['loss/mal']  * 0.001 \
+            - metrics_dict['loss/all']  * 0.0001
+
+        return score
+
+    # def logModelMetrics(self, model):
+    #     writer = getattr(self, 'trn_writer')
+    #
+    #     model = getattr(model, 'module', model)
+    #
+    #     for name, param in model.named_parameters():
+    #         if param.requires_grad:
+    #             min_data = float(param.data.min())
+    #             max_data = float(param.data.max())
+    #             max_extent = max(abs(min_data), abs(max_data))
+    #
+    #             # bins = [x/50*max_extent for x in range(-50, 51)]
+    #
+    #             writer.add_histogram(
+    #                 name.rsplit('.', 1)[-1] + '/' + name,
+    #                 param.data.cpu().numpy(),
+    #                 # metrics_ary[METRICS_PRED_NDX, benHist_mask],
+    #                 self.totalTrainingSamples_count,
+    #                 # bins=bins,
+    #             )
+    #
+    #             # print name, param.data
+
+    def saveModel(self, type_str, epoch_ndx, isBest=False):
+        file_path = os.path.join(
+            'data-unversioned',
+            'part2',
+            'models',
+            self.cli_args.tb_prefix,
+            '{}_{}_{}.{}.state'.format(
+                type_str,
+                self.time_str,
+                self.cli_args.comment,
+                self.totalTrainingSamples_count,
+            )
+        )
+
+        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
+
+        model = self.model
+        if hasattr(model, 'module'):
+            model = model.module
+
+        state = {
+            'model_state': model.state_dict(),
+            'model_name': type(model).__name__,
+            'optimizer_state' : self.optimizer.state_dict(),
+            'optimizer_name': type(self.optimizer).__name__,
+            'epoch': epoch_ndx,
+            'totalTrainingSamples_count': self.totalTrainingSamples_count,
+        }
+        torch.save(state, file_path)
+
+        log.debug("Saved model params to {}".format(file_path))
+
+        if isBest:
+            file_path = os.path.join(
+                'data-unversioned',
+                'part2',
+                'models',
+                self.cli_args.tb_prefix,
+                '{}_{}_{}.{}.state'.format(
+                    type_str,
+                    self.time_str,
+                    self.cli_args.comment,
+                    'best',
+                )
+            )
+            torch.save(state, file_path)
+
+            log.debug("Saved model params to {}".format(file_path))
+
+
+if __name__ == '__main__':
+    sys.exit(LunaTrainingApp().main() or 0)
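
A self-contained sketch of the per-sample soft Dice loss that diceLoss above implements; the epsilon value here is illustrative (the two training scripts above use different defaults):

import torch

def dice_loss(label_t, prediction_t, epsilon=1.0):
    # Flatten everything except the batch dimension and sum per sample.
    sum_per_sample = lambda t: t.view(t.size(0), -1).sum(dim=1)
    label_sum = sum_per_sample(label_t)
    pred_sum = sum_per_sample(prediction_t)
    overlap_sum = sum_per_sample(label_t * prediction_t)
    # Perfect overlap gives a loss of 0; no overlap gives a loss near 1.
    return 1 - (2 * overlap_sum + epsilon) / (pred_sum + label_sum + epsilon)

# Example: a perfect prediction on a one-sample, four-pixel mask.
label_t = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
print(dice_loss(label_t, label_t))  # tensor([0.])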

+ 554 - 0
p2ch12/training.py

@@ -0,0 +1,554 @@
+import argparse
+import datetime
+import os
+import socket
+import sys
+
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.optim import SGD, Adam
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
+from util.logconf import logging
+from util.util import xyz2irc
+from .model import UNetWrapper
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logPerformanceMetrics to index into metrics_tensor/metrics_ary
+METRICS_LABEL_NDX = 0
+METRICS_LOSS_NDX = 1
+METRICS_MAL_LOSS_NDX = 2
+METRICS_BEN_LOSS_NDX = 3
+
+METRICS_MTP_NDX = 4
+METRICS_MFN_NDX = 5
+METRICS_MFP_NDX = 6
+METRICS_BTP_NDX = 7
+METRICS_BFN_NDX = 8
+# METRICS_BFP_NDX = 9
+
+METRICS_SIZE = 9
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=24,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+
+        parser.add_argument('--augmented',
+            help="Augment the training data.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-flip',
+            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
+            action='store_true',
+            default=False,
+        )
+        # parser.add_argument('--augment-offset',
+        #     help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
+        #     action='store_true',
+        #     default=False,
+        # )
+        # parser.add_argument('--augment-scale',
+        #     help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
+        #     action='store_true',
+        #     default=False,
+        # )
+        parser.add_argument('--augment-rotate',
+            help="Augment the training data by randomly rotating the data around the head-foot axis.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-noise',
+            help="Augment the training data by randomly adding noise to the data.",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch12',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='none',
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
+
+        self.trn_writer = None
+        self.tst_writer = None
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        # # TODO: remove this if block before print
+        # # This is due to an odd setup that the author is using to test the code; please ignore for now
+        # if socket.gethostname() == 'c2':
+        #     self.device = torch.device("cuda:1")
+
+        self.model = self.initModel()
+        self.optimizer = self.initOptimizer()
+
+        self.totalTrainingSamples_count = 0
+
+        augmentation_dict = {}
+        if self.cli_args.augmented or self.cli_args.augment_flip:
+            augmentation_dict['flip'] = True
+        if self.cli_args.augmented or self.cli_args.augment_rotate:
+            augmentation_dict['rotate'] = True
+        if self.cli_args.augmented or self.cli_args.augment_noise:
+            augmentation_dict['noise'] = 25.0
+        self.augmentation_dict = augmentation_dict
+
+
+    def initModel(self):
+        # model = UNetWrapper(in_channels=8, n_classes=2, depth=3, wf=6, padding=True, batch_norm=True, up_mode='upconv')
+        model = UNetWrapper(in_channels=7, n_classes=1, depth=4, wf=3, padding=True, batch_norm=True, up_mode='upconv')
+
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                model = nn.DataParallel(model)
+            model = model.to(self.device)
+
+        return model
+
+    def initOptimizer(self):
+        return SGD(self.model.parameters(), lr=0.01, momentum=0.99)
+        # return Adam(self.model.parameters())
+
+
+    def initTrainDl(self):
+        train_ds = TrainingLuna2dSegmentationDataset(
+            test_stride=10,
+            contextSlices_count=3,
+            augmentation_dict=self.augmentation_dict,
+        )
+
+        train_dl = DataLoader(
+            train_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return train_dl
+
+    def initTestDl(self):
+        test_ds = Luna2dSegmentationDataset(
+            test_stride=10,
+            contextSlices_count=3,
+        )
+
+        test_dl = DataLoader(
+            test_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return test_dl
+
+    def initTensorboardWriters(self):
+        if self.trn_writer is None:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
+            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_seg_' + self.cli_args.comment)
+
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+        test_dl = self.initTestDl()
+
+        self.initTensorboardWriters()
+        # self.logModelMetrics(self.model)
+
+        best_score = 0.0
+
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(test_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
+            self.logPerformanceMetrics(epoch_ndx, 'trn', trainingMetrics_tensor)
+            self.logImages(epoch_ndx, train_dl, test_dl)
+            # self.logModelMetrics(self.model)
+
+            testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
+            score = self.logPerformanceMetrics(epoch_ndx, 'tst', testingMetrics_tensor)
+            best_score = max(score, best_score)
+
+            self.saveModel('seg', epoch_ndx, score == best_score)
+
+        if hasattr(self, 'trn_writer'):
+            self.trn_writer.close()
+            self.tst_writer.close()
+
+    def doTraining(self, epoch_ndx, train_dl):
+        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        self.model.train()
+
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_devtensor)
+            loss_var.backward()
+
+            self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
+
+        return trainingMetrics_devtensor.to('cpu')
+
+    def doTesting(self, epoch_ndx, test_dl):
+        with torch.no_grad():
+            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
+            self.model.eval()
+
+            batch_iter = enumerateWithEstimate(
+                test_dl,
+                "E{} Testing ".format(epoch_ndx),
+                start_ndx=test_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
+
+        return testingMetrics_devtensor.to('cpu')
+
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
+        input_tensor, label_tensor, label_list, ben_tensor, mal_tensor, _series_list, _start_list = batch_tup
+
+        input_devtensor = input_tensor.to(self.device, non_blocking=True)
+        label_devtensor = label_tensor.to(self.device, non_blocking=True)
+        mal_devtensor = mal_tensor.to(self.device, non_blocking=True)
+        ben_devtensor = ben_tensor.to(self.device, non_blocking=True)
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
+        intersectionSum = lambda a, b: (a * b).view(a.size(0), -1).sum(dim=1)
+
+        prediction_devtensor = self.model(input_devtensor)
+        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
+
+        with torch.no_grad():
+            predictionBool_devtensor = (prediction_devtensor > 0.5).to(torch.float32)
+
+            metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
+            metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_devtensor
+
+            malPred_devtensor = predictionBool_devtensor * (1 - ben_devtensor)
+
+            tp = intersectionSum(    mal_devtensor,     malPred_devtensor)
+            fn = intersectionSum(    mal_devtensor, 1 - malPred_devtensor)
+            fp = intersectionSum(1 - mal_devtensor,     malPred_devtensor)
+            ls = self.diceLoss(mal_devtensor, malPred_devtensor)
+
+            metrics_devtensor[METRICS_MTP_NDX, start_ndx:end_ndx] = tp
+            metrics_devtensor[METRICS_MFN_NDX, start_ndx:end_ndx] = fn
+            metrics_devtensor[METRICS_MFP_NDX, start_ndx:end_ndx] = fp
+            metrics_devtensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = ls
+
+            del malPred_devtensor, tp, fn, fp, ls
+
+            benPred_devtensor = predictionBool_devtensor * (1 - mal_devtensor)
+            tp = intersectionSum(    ben_devtensor,     benPred_devtensor)
+            fn = intersectionSum(    ben_devtensor, 1 - benPred_devtensor)
+            # fp = intersectionSum(1 - ben_devtensor,     benPred_devtensor)
+            ls = self.diceLoss(ben_devtensor, benPred_devtensor)
+
+            metrics_devtensor[METRICS_BTP_NDX, start_ndx:end_ndx] = tp
+            metrics_devtensor[METRICS_BFN_NDX, start_ndx:end_ndx] = fn
+            # metrics_devtensor[METRICS_BFP_NDX, start_ndx:end_ndx] = fp
+            metrics_devtensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = ls
+
+            del benPred_devtensor, tp, fn, ls
+
+        return diceLoss_devtensor.mean()
+
+    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01, p=False):
+        sum_dim1 = lambda t: t.view(t.size(0), -1).sum(dim=1)
+
+        diceLabel_devtensor = sum_dim1(label_devtensor)
+        dicePrediction_devtensor = sum_dim1(prediction_devtensor)
+        diceCorrect_devtensor = sum_dim1(prediction_devtensor * label_devtensor)
+
+        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
+        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)
+
+        return diceLoss_devtensor
+
+
+
+    def logImages(self, epoch_ndx, train_dl, test_dl):
+        for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
+            for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
+                ct = getCt(series_uid)
+                noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
+                center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)
+
+                sample_tup = dl.dataset[(series_uid, int(center_irc.index))]
+                # input_tensor = sample_tup[0].unsqueeze(0)
+                # label_tensor = sample_tup[1].unsqueeze(0)
+
+                input_tensor, label_tensor, ben_tensor, mal_tensor = sample_tup[:4]
+                input_tensor += 1000
+                input_tensor /= 2001
+
+                input_devtensor = input_tensor.to(self.device)
+                # label_devtensor = label_tensor.to(self.device)
+
+                prediction_devtensor = self.model(input_devtensor.unsqueeze(0))[0]
+                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()
+                label_ary = label_tensor.numpy()
+                ben_ary = ben_tensor.numpy()
+                mal_ary = mal_tensor.numpy()
+
+                # log.debug([prediction_ary[0].shape, label_ary.shape, mal_ary.shape])
+
+                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
+                image_ary[:,:,:] = (input_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1))) * 0.5
+                image_ary[:,:,0] += prediction_ary[0] * (1 - label_ary[0]) * 0.5
+                image_ary[:,:,1] += prediction_ary[0] * mal_ary * 0.5
+                image_ary[:,:,2] += prediction_ary[0] * ben_ary * 0.5
+                # image_ary[:,:,2] += prediction_ary[0,1] * 0.25
+                # image_ary[:,:,2] += prediction_ary[0,2] * 0.5
+
+                # log.debug([image_ary.__array_interface__['typestr']])
+
+                # image_ary = (image_ary * 255).astype(np.uint8)
+
+                # log.debug([image_ary.__array_interface__['typestr']])
+
+                writer = getattr(self, mode_str + '_writer')
+                try:
+                    image_ary[image_ary < 0] = 0
+                    image_ary[image_ary > 1] = 1
+                    writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+                except Exception:
+                    log.debug([image_ary.shape, image_ary.dtype])
+                    raise
+
+                if epoch_ndx == 1:
+
+                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
+                    image_ary[:,:,:] = (input_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1))) * 0.5
+                    image_ary[:,:,1] += mal_ary * 0.5
+                    image_ary[:,:,2] += ben_ary * 0.5
+                    # image_ary[:,:,2] += label_ary[0,1] * 0.25
+                    # image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(np.bool) | label_ary[0,1].astype(np.bool))) * 0.25
+
+                    # log.debug([image_ary.__array_interface__['typestr']])
+
+                    # image_ary = (image_ary * 255).astype(np.uint8)
+
+                    # log.debug([image_ary.__array_interface__['typestr']])
+
+                    writer = getattr(self, mode_str + '_writer')
+                    image_ary[image_ary < 0] = 0
+                    image_ary[image_ary > 1] = 1
+                    writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+
+
+    def logPerformanceMetrics(self,
+                              epoch_ndx,
+                              mode_str,
+                              metrics_tensor,
+                              ):
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+
+        # for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
+        metrics_ary = metrics_tensor.cpu().detach().numpy()
+        sum_ary = metrics_ary.sum(axis=1)
+        assert np.isfinite(metrics_ary).all()
+
+        malLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 1) | (metrics_ary[METRICS_LABEL_NDX] == 3)
+
+        benLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 2) | (metrics_ary[METRICS_LABEL_NDX] == 3)
+        # malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
+
+        # malLabel_mask = ~benLabel_mask
+        # malPred_mask = ~benPred_mask
+
+        benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
+        malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]
+
+        trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
+        truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]
+#
+#             falsePos_count = benLabel_count - benCorrect_count
+#             falseNeg_count = malLabel_count - malCorrect_count
+
+
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+        # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
+        # metrics_dict['loss/mal'] = metrics_ary[METRICS_MALLOSS_NDX].mean()
+        # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = np.nan_to_num(metrics_ary[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
+        metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX, benLabel_mask].mean()
+        # metrics_dict['loss/flg'] = metrics_ary[METRICS_FLG_LOSS_NDX].mean()
+
+        # metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
+        # metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100
+
+        metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
+        metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100
+
+        precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
+        recall    = metrics_dict['pr/recall']    = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
+
+        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)
+
+        log.info(("E{} {:8} "
+                 + "{loss/all:.4f} loss, "
+                 # + "{loss/flg:.4f} flagged loss, "
+                 # + "{flagged/all:-5.1f}% pixels flagged, "
+                 # + "{flagged/slices:-5.1f}% slices flagged, "
+                 + "{pr/precision:.4f} precision, "
+                 + "{pr/recall:.4f} recall, "
+                 + "{pr/f1_score:.4f} f1 score"
+                  ).format(
+            epoch_ndx,
+            mode_str,
+            **metrics_dict,
+        ))
+        log.info(("E{} {:8} "
+                 + "{loss/mal:.4f} loss, "
+                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+        ).format(
+            epoch_ndx,
+            mode_str + '_mal',
+            malCorrect_count=malCorrect_count,
+            malLabel_count=malLabel_count,
+            **metrics_dict,
+        ))
+        log.info(("E{} {:8} "
+                 + "{loss/ben:.4f} loss, "
+                 + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
+        ).format(
+            epoch_ndx,
+            mode_str + '_ben',
+            benCorrect_count=benCorrect_count,
+            benLabel_count=benLabel_count,
+            **metrics_dict,
+        ))
+
+        writer = getattr(self, mode_str + '_writer')
+
+        prefix_str = 'seg_'
+
+        for key, value in metrics_dict.items():
+            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)
+
+        score = 1 \
+            + metrics_dict['pr/f1_score'] \
+            - metrics_dict['loss/mal'] * 0.01 \
+            - metrics_dict['loss/all'] * 0.0001
+
+        return score
+
+    # def logModelMetrics(self, model):
+    #     writer = getattr(self, 'trn_writer')
+    #
+    #     model = getattr(model, 'module', model)
+    #
+    #     for name, param in model.named_parameters():
+    #         if param.requires_grad:
+    #             min_data = float(param.data.min())
+    #             max_data = float(param.data.max())
+    #             max_extent = max(abs(min_data), abs(max_data))
+    #
+    #             # bins = [x/50*max_extent for x in range(-50, 51)]
+    #
+    #             writer.add_histogram(
+    #                 name.rsplit('.', 1)[-1] + '/' + name,
+    #                 param.data.cpu().numpy(),
+    #                 # metrics_ary[METRICS_PRED_NDX, benHist_mask],
+    #                 self.totalTrainingSamples_count,
+    #                 # bins=bins,
+    #             )
+    #
+    #             # print name, param.data
+
+    def saveModel(self, type_str, epoch_ndx, isBest=False):
+        file_path = os.path.join('data-unversioned', 'models', self.cli_args.tb_prefix, '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, self.totalTrainingSamples_count))
+
+        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
+
+        model = self.model
+        if hasattr(model, 'module'):
+            model = model.module
+
+        state = {
+            'model_state': model.state_dict(),
+            'model_name': type(model).__name__,
+            'optimizer_state' : self.optimizer.state_dict(),
+            'optimizer_name': type(self.optimizer).__name__,
+            'epoch': epoch_ndx,
+            'totalTrainingSamples_count': self.totalTrainingSamples_count,
+            # 'resumed_from': self.cli_args.resume,
+        }
+        torch.save(state, file_path)
+
+        log.debug("Saved model params to {}".format(file_path))
+
+        if isBest:
+            file_path = os.path.join('data-unversioned', 'models', self.cli_args.tb_prefix, '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, 'best'))
+            torch.save(state, file_path)
+
+            log.debug("Saved model params to {}".format(file_path))
+
+
+if __name__ == '__main__':
+    sys.exit(LunaTrainingApp().main() or 0)
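
A small sketch (made-up sizes) of the per-sample metrics bookkeeping pattern used by computeBatchLoss in both training scripts above: one row per metric, one column per sample in the epoch, filled batch by batch via start_ndx/end_ndx slices:

import torch

METRICS_SIZE, dataset_len, batch_size = 9, 100, 24
metrics_t = torch.zeros(METRICS_SIZE, dataset_len)

for batch_ndx in range(5):  # five batches of up to 24 samples cover all 100
    n_samples = min(batch_size, dataset_len - batch_ndx * batch_size)
    start_ndx = batch_ndx * batch_size
    end_ndx = start_ndx + n_samples          # the last batch is smaller
    metrics_t[1, start_ndx:end_ndx] = 0.5    # e.g. a per-sample loss value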

+ 86 - 0
p2ch12/vis.py

@@ -0,0 +1,86 @@
+import matplotlib
+matplotlib.use('nbagg')
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from p2ch12.dsets import Ct, LunaDataset
+
+clim=(0.0, 1.3)
+
+def findMalignantSamples(start_ndx=0, limit=10):
+    ds = LunaDataset()
+
+    malignantSample_list = []
+    for sample_tup in ds.sample_list:
+        if sample_tup[2]:
+            print(len(malignantSample_list), sample_tup)
+            malignantSample_list.append(sample_tup)
+
+        if len(malignantSample_list) >= limit:
+            break
+
+    return malignantSample_list
+
+def showNodule(series_uid, batch_ndx=None, **kwargs):
+    ds = LunaDataset(series_uid=series_uid, **kwargs)
+    malignant_list = [i for i, x in enumerate(ds.sample_list) if x[2]]
+
+    if batch_ndx is None:
+        if malignant_list:
+            batch_ndx = malignant_list[0]
+        else:
+            print("Warning: no malignant samples found; using first non-malignant sample.")
+            batch_ndx = 0
+
+    ct = Ct(series_uid)
+    # ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
+    malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor = ds[batch_ndx]
+    ct_ary = nodule_tensor[1].numpy()
+
+
+    fig = plt.figure(figsize=(15, 25))
+
+    group_list = [
+        #[0,1,2],
+        [9,11,13],
+        [15, 16, 17],
+        [19,21,23],
+        #[12,13,14],
+        #[15]
+    ]
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct_ary[ct_ary.shape[0]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct_ary[:,ct_ary.shape[1]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct_ary[:,:,ct_ary.shape[2]//2], clim=clim, cmap='gray')
+
+    for row, index_list in enumerate(group_list):
+        for col, index in enumerate(index_list):
+            subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
+            subplot.set_title('slice {}'.format(index))
+            plt.imshow(ct_ary[index*2], clim=clim, cmap='gray')
+
+
+    print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list, ct.vxSize_xyz)
+
+    return ct_ary

File diff suppressed because it is too large
+ 115 - 0
p2ch12_explore_data.ipynb


File diff suppressed because it is too large
+ 122 - 0
p2ch12_explore_diagnose.ipynb


+ 1 - 1
util/disk.py

@@ -80,7 +80,7 @@ class GzipDisk(Disk):
 def getCache(scope_str):
     return FanoutCache('data-unversioned/cache/' + scope_str,
                        disk=GzipDisk,
-                       shards=128,
+                       shards=64,
                        timeout=1,
                        size_limit=2e11,
                        # disk_min_file_size=2**20,

+ 4 - 4
util/unet.py

@@ -97,15 +97,15 @@ class UNetConvBlock(nn.Module):
 
         block.append(nn.Conv2d(in_size, out_size, kernel_size=3,
                                padding=int(padding)))
-        # block.append(nn.ReLU())
-        block.append(nn.LeakyReLU())
+        block.append(nn.ReLU())
+        # block.append(nn.LeakyReLU())
         if batch_norm:
             block.append(nn.BatchNorm2d(out_size))
 
         block.append(nn.Conv2d(out_size, out_size, kernel_size=3,
                                padding=int(padding)))
-        # block.append(nn.ReLU())
-        block.append(nn.LeakyReLU())
+        block.append(nn.ReLU())
+        # block.append(nn.LeakyReLU())
         if batch_norm:
             block.append(nn.BatchNorm2d(out_size))
 

+ 51 - 56
util/util.py

@@ -39,7 +39,6 @@ def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_tup):
     else:
         raise Exception("Unsupported direction_tup: {}".format(direction_tup))
 
-
     coord_xyz = coord_cri * direction_ary * np.array(vxSize_xyz) + np.array(origin_xyz)
     return XyzTuple(*coord_xyz.tolist())
 
@@ -153,63 +152,59 @@ def prhist(ary, prefix_str=None, **kwargs):
 
 def enumerateWithEstimate(iter, desc_str, start_ndx=0, print_ndx=4, backoff=2, iter_len=None):
     """
-
-    :param iter: `iter` is the iterable that will be passed into `enumerate`. Required.
-
-    :param desc_str: This is a human-readable string that describes what the loop is doing.
-The value is arbitrary, but should be kept reasonably short.
-Things like `"epoch 4 training"` or `"deleting temp files"` or similar
-would all make sense.
-
-    :param start_ndx:
-    :param print_ndx:
-    :param backoff:
-    :param iter_len: Since we need to know the number of items to estimate when the loop will finish,
-that can be provided by passing in a value for `iter_len`.
-If a value isn't provided, then it will be set by using the value of `len(iter)`.
+    In terms of behavior, `enumerateWithEstimate` is almost identical
+    to the standard `enumerate` (the differences are things like how
+    our function returns a generator, while `enumerate` returns a
+    specialized `<enumerate object at 0x...>`).
+
+    However, the side effects (logging, specifically) are what make the
+    function interesting.
+
+    :param iter: `iter` is the iterable that will be passed into
+        `enumerate`. Required.
+
+    :param desc_str: This is a human-readable string that describes
+        what the loop is doing. The value is arbitrary, but should be
+        kept reasonably short. Things like `"epoch 4 training"` or
+        `"deleting temp files"` or similar would all make sense.
+
+    :param start_ndx: This parameter defines how many iterations of the
+        loop should be skipped before timing actually starts. Skipping
+        a few iterations can be useful if there are startup costs like
+        caching that are only paid early on, resulting in a skewed
+        average when those early iterations dominate the average time
+        per iteration.
+
+        NOTE: Using `start_ndx` to skip some iterations means the time
+        spent performing those iterations is not included in the
+        displayed duration. Please account for this if you use the
+        displayed duration for anything formal.
+
+        This parameter defaults to `0`.
+
+    :param print_ndx: determines which loop iteration the timing
+        logging will start on. The intent is that we don't start
+        logging until we've given the loop a few iterations to let the
+        average time-per-iteration stabilize a bit. We
+        require that `print_ndx` not be less than `start_ndx` times
+        `backoff`, since `start_ndx` greater than `0` implies that the
+        early N iterations are unstable from a timing perspective.
+
+        `print_ndx` defaults to `4`.
+
+    :param backoff: This is used to control how many iterations to skip before
+        logging again. Frequent logging is less interesting later on,
+        so by default we double the gap between logging messages each
+        time after the first.
+
+        `backoff` defaults to `2`.
+
+    :param iter_len: Since we need to know the number of items to
+        estimate when the loop will finish, that can be provided by
+        passing in a value for `iter_len`. If a value isn't provided,
+        then it will be set by using the value of `len(iter)`.
 
     :return:
-
-
-==== Required argument: `iter` and optionally `iter_len`
-
-These two are pretty simple.
-
-==== Required argument: `desc_str`
-
-
-
-==== Optional argument: `start_ndx`
-
-This parameter defines how many iterations of the loop should be skipped
-before timing actually starts.
-Skipping a few iterations can be useful if there are startup costs
-like caching that are only paid early on,
-resulting in a skewed average
-when those early iterations dominate the average time per iteration.
-
-NOTE: Using `start_ndx` to skip some iterations makes the time spent
-performing those iterations not be included
-in the displayed duration.
-Please account for this if you use the displayed duration for anything formal.
-
-This parameter defaults to `0`.
-
-==== Optional arguments: `print_ndx` and `backoff`
-
-`print_ndx` determines which loop interation that the timing logging will start on,
-and `backoff` is used to how many iterations to skip before logging again.
-The intent is that we don't start logging until we've given the loop
-a few iterations to let the average time-per-iteration a chance to stablize a bit.
-We require that `print_ndx` not be less than `start_ndx` times `backoff`,
-since `start_ndx` greater than `0` implies that the early N iterations
-are unstable from a timing perspective.
-Frequent logging is less interesting later on,
-so by default we double the gap between logging messages each time after the first.
-
-
-`print_ndx` defaults to `4` and `backoff` defaults to `2`.
-
     """
     if iter_len is None:
         iter_len = len(iter)
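
A usage sketch matching the call sites in the training loops above; enumerateWithEstimate wraps a DataLoader just like enumerate would, while logging periodic estimates of when the loop will finish (train_dl and epoch_ndx come from the surrounding training code):

batch_iter = enumerateWithEstimate(
    train_dl,
    "E{} Training".format(epoch_ndx),
    start_ndx=train_dl.num_workers,
)
for batch_ndx, batch_tup in batch_iter:
    ...  # per-batch work, exactly as in doTraining above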

Some files were not shown because too many files changed in this diff