
Update for latest MEAP, PyTorch v1.3.0, p2ch14.

Eli Stevens, 6 years ago
commit c9ad558247
62 changed files with 7829 additions and 1761 deletions
  1. p1ch3/1_tensors.ipynb (+8 -9)
  2. p1ch4/2_time_series_bikes.ipynb (+1 -1)
  3. p1ch4/3_text_jane_austin.ipynb (+8 -8)
  4. p1ch4/4_audio_chirp.ipynb (+5 -5)
  5. p1ch4/5_image_dog.ipynb (+1 -1)
  6. p1ch6/1_neural_networks.ipynb (+2 -2)
  7. p1ch7/1_datasets.ipynb (+1 -1)
  8. p1ch8/1_convolution.ipynb (+1 -1)
  9. p2_run_everything.ipynb (+321 -0)
  10. p2ch09/dsets.py (+25 -23)
  11. p2ch09/vis.py (+10 -10)
  12. p2ch09_explore_data.ipynb (+100 -12)
  13. p2ch10/dsets.py (+32 -40)
  14. p2ch10/model.py (+1 -1)
  15. p2ch10/training.py (+66 -63)
  16. p2ch10/vis.py (+46 -33)
  17. p2ch10_explore_data.ipynb (+77 -63)
  18. p2ch11/1_final_metric_f1_score.ipynb (+7 -7)
  19. p2ch11/dsets.py (+39 -163)
  20. p2ch11/model.py (+45 -28)
  21. p2ch11/prepcache.py (+1 -1)
  22. p2ch11/training.py (+103 -152)
  23. p2ch11/vis.py (+45 -33)
  24. p2ch12/1_final_metric_f1_score.ipynb (+59 -0)
  25. p2ch12/diagnose.py (+38 -44)
  26. p2ch12/dsets.py (+58 -311)
  27. p2ch12/model.py (+53 -14)
  28. p2ch12/prepcache.py (+5 -14)
  29. p2ch12/screencts.py (+1 -1)
  30. p2ch12/train_cls.py (+77 -69)
  31. p2ch12/train_seg.py (+186 -144)
  32. p2ch12/training.py (+210 -348)
  33. p2ch12/vis.py (+44 -31)
  34. p2ch12_explore_data.ipynb (+4 -40)
  35. p2ch12_explore_diagnose.ipynb (+14 -23)
  36. p2ch13/__init__.py (+0 -0)
  37. p2ch13/diagnose.py (+372 -0)
  38. p2ch13/dsets.py (+566 -0)
  39. p2ch13/model.py (+68 -0)
  40. p2ch13/model_cls.py (+92 -0)
  41. p2ch13/model_seg.py (+46 -0)
  42. p2ch13/model_segmentation.py (+328 -0)
  43. p2ch13/prepcache.py (+68 -0)
  44. p2ch13/screencts.py (+92 -0)
  45. p2ch13/train_cls.py (+459 -0)
  46. p2ch13/train_seg.py (+580 -0)
  47. p2ch13/training.py (+702 -0)
  48. p2ch13/vis.py (+99 -0)
  49. p2ch13_explore_data.ipynb (+79 -0)
  50. p2ch13_explore_diagnose.ipynb (+113 -0)
  51. p2ch14/__init__.py (+0 -0)
  52. p2ch14/diagnose.py (+372 -0)
  53. p2ch14/dsets.py (+580 -0)
  54. p2ch14/model_cls.py (+92 -0)
  55. p2ch14/model_seg.py (+46 -0)
  56. p2ch14/prepcache.py (+68 -0)
  57. p2ch14/screencts.py (+92 -0)
  58. p2ch14/train_cls.py (+462 -0)
  59. p2ch14/train_seg.py (+580 -0)
  60. p2ch14/vis.py (+99 -0)
  61. util/test_affine.py (+56 -56)
  62. util/util.py (+24 -9)

+ 8 - 9
p1ch3/1_tensors.ipynb

@@ -510,7 +510,6 @@
     }
    ],
    "source": [
-    "points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])\n",
     "second_point = points[1]\n",
     "second_point.size()"
    ]
@@ -727,9 +726,9 @@
     }
    ],
    "source": [
-    "some_tensor = torch.ones(3, 4, 5)\n",
-    "some_tensor_t = some_tensor.transpose(0, 2)\n",
-    "some_tensor.shape"
+    "some_t = torch.ones(3, 4, 5)\n",
+    "transpose_t = some_t.transpose(0, 2)\n",
+    "some_t.shape"
    ]
   },
   {
@@ -749,7 +748,7 @@
     }
    ],
    "source": [
-    "some_tensor_t.shape"
+    "transpose_t.shape"
    ]
   },
   {
@@ -769,7 +768,7 @@
     }
    ],
    "source": [
-    "some_tensor.stride()"
+    "some_t.stride()"
    ]
   },
   {
@@ -789,7 +788,7 @@
     }
    ],
    "source": [
-    "some_tensor_t.stride()"
+    "transpose_t.stride()"
    ]
   },
   {
@@ -1179,7 +1178,7 @@
    "source": [
     "f = h5py.File('../data/p1ch3/ourpoints.hdf5', 'r')\n",
     "dset = f['coords']\n",
-    "last_points = dset[1:]"
+    "last_points = dset[-2:]"
    ]
   },
   {
@@ -1188,7 +1187,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "last_points = torch.from_numpy(dset[1:])\n",
+    "last_points = torch.from_numpy(dset[-2:])\n",
     "f.close()"
    ]
   },
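
Note: besides the rename (some_tensor → some_t, matching the book's shorter variable style), the stride behavior this notebook demonstrates is worth a quick sanity check. A minimal sketch, not part of the commit:

import torch

some_t = torch.ones(3, 4, 5)
transpose_t = some_t.transpose(0, 2)

# A contiguous (3, 4, 5) tensor has strides (20, 5, 1); transpose(0, 2)
# reverses shape and strides without copying the underlying storage.
assert some_t.shape == torch.Size([3, 4, 5])
assert transpose_t.shape == torch.Size([5, 4, 3])
assert some_t.stride() == (20, 5, 1)
assert transpose_t.stride() == (1, 5, 20)
assert some_t.data_ptr() == transpose_t.data_ptr()   # shared storage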

+ 1 - 1
p1ch4/2_time_series_bikes.ipynb

@@ -150,7 +150,7 @@
    "source": [
     "weather_onehot.scatter_(\n",
     "    dim=1, \n",
-    "    index=first_day[:,9].unsqueeze(1) - 1, # <1>\n",
+    "    index=first_day[:,9].unsqueeze(1).long() - 1, # <1>\n",
     "    value=1.0)"
    ]
   },
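
Note: the added .long() cast is needed because Tensor.scatter_ requires an int64 index tensor, and the bikes data loads as float. A minimal one-hot sketch of the same call:

import torch

weather = torch.tensor([1.0, 2.0, 1.0, 3.0])   # weather levels 1..4, as float
weather_onehot = torch.zeros(weather.shape[0], 4)
weather_onehot.scatter_(
    dim=1,
    index=weather.unsqueeze(1).long() - 1,     # cast to int64 before indexing
    value=1.0)
# rows: [1,0,0,0], [0,1,0,0], [1,0,0,0], [0,0,1,0]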

+ 8 - 8
p1ch4/3_text_jane_austin.ipynb

@@ -60,8 +60,8 @@
     }
    ],
    "source": [
-    "letter_tensor = torch.zeros(len(line), 128) # <1> \n",
-    "letter_tensor.shape"
+    "letter_t = torch.zeros(len(line), 128) # <1> \n",
+    "letter_t.shape"
    ]
   },
   {
@@ -72,7 +72,7 @@
    "source": [
     "for i, letter in enumerate(line.lower().strip()):\n",
     "    letter_index = ord(letter) if ord(letter) < 128 else 0  # <1>\n",
-    "    letter_tensor[i][letter_index] = 1"
+    "    letter_t[i][letter_index] = 1"
    ]
   },
   {
@@ -161,13 +161,13 @@
     }
    ],
    "source": [
-    "word_tensor = torch.zeros(len(words_in_line), len(word2index_dict))\n",
+    "word_t = torch.zeros(len(words_in_line), len(word2index_dict))\n",
     "for i, word in enumerate(words_in_line):\n",
     "    word_index = word2index_dict[word]\n",
-    "    word_tensor[i][word_index] = 1\n",
+    "    word_t[i][word_index] = 1\n",
     "    print('{:2} {:4} {}'.format(i, word_index, word))\n",
     "    \n",
-    "print(word_tensor.shape)\n"
+    "print(word_t.shape)\n"
    ]
   },
   {
@@ -187,8 +187,8 @@
     }
    ],
    "source": [
-    "word_tensor = word_tensor.unsqueeze(1)\n",
-    "word_tensor.shape"
+    "word_t = word_t.unsqueeze(1)\n",
+    "word_t.shape"
    ]
   },
   {

+ 5 - 5
p1ch4/4_audio_chirp.ipynb

@@ -167,9 +167,9 @@
    ],
    "source": [
     "sp_left = sp_right = sp_arr\n",
-    "sp_left_tensor = torch.from_numpy(sp_left)\n",
-    "sp_right_tensor = torch.from_numpy(sp_right)\n",
-    "sp_left_tensor.shape, sp_right_tensor.shape"
+    "sp_left_t = torch.from_numpy(sp_left)\n",
+    "sp_right_t = torch.from_numpy(sp_right)\n",
+    "sp_left_t.shape, sp_right_t.shape"
    ]
   },
   {
@@ -197,8 +197,8 @@
     }
    ],
    "source": [
-    "sp_tensor = torch.stack((sp_left_tensor, sp_right_tensor), dim=0)\n",
-    "sp_tensor.shape"
+    "sp_t = torch.stack((sp_left_t, sp_right_t), dim=0)\n",
+    "sp_t.shape"
    ]
   },
   {

+ 1 - 1
p1ch4/5_image_dog.ipynb

@@ -63,7 +63,7 @@
     "import os\n",
     "\n",
     "data_dir = '../data/p1ch4/image-cats/'\n",
-    "filenames = [name for name in os.listdir(data_dir) if os.path.splitext(name) == '.png']\n",
+    "filenames = [name for name in os.listdir(data_dir) if os.path.splitext(name)[-1] == '.png']\n",
     "for i, filename in enumerate(filenames):\n",
     "    img_arr = imageio.imread(filename)\n",
     "    batch[i] = torch.transpose(torch.from_numpy(img_arr), 0, 2)"

+ 2 - 2
p1ch6/1_neural_networks.ipynb

@@ -262,10 +262,10 @@
    "source": [
     "def training_loop(n_epochs, optimizer, model, loss_fn, t_u_train, t_u_val, t_c_train, t_c_val):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
-    "        t_p_train = model(t_un_train) # <1>\n",
+    "        t_p_train = model(t_u_train) # <1>\n",
     "        loss_train = loss_fn(t_p_train, t_c_train)\n",
     "\n",
-    "        t_p_val = model(t_un_val) # <1>\n",
+    "        t_p_val = model(t_u_val) # <1>\n",
     "        loss_val = loss_fn(t_p_val, t_c_val)\n",
     "        \n",
     "        optimizer.zero_grad()\n",

+ 1 - 1
p1ch7/1_datasets.ipynb

@@ -42,7 +42,7 @@
    ],
    "source": [
     "from torchvision import datasets\n",
-    "data_path = '../data-unversioned/p1ch6/''\n",
+    "data_path = '../data-unversioned/p1ch6/'\n",
     "cifar10 = datasets.CIFAR10(data_path, train=True, download=True)\n",
     "cifar10_val = datasets.CIFAR10(data_path, train=False, download=True)"
    ]

+ 1 - 1
p1ch8/1_convolution.ipynb

File diff suppressed because it is too large

+ 321 - 0
p2_run_everything.ipynb

@@ -0,0 +1,321 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import datetime\n",
+    "\n",
+    "from util.util import importstr\n",
+    "from util.logconf import logging\n",
+    "log = logging.getLogger('nb')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def run(app, *argv):\n",
+    "    argv.insert(0, '--num-workers=6')  # <1>\n",
+    "    log.info(\"Running: {}({!r}).main()\".format(app, argv))\n",
+    "    \n",
+    "    app_cls = importstr(*app.rsplit('.', 1))  # <2>\n",
+    "    app_cls(argv).main()\n",
+    "    \n",
+    "    log.info(\"Finished: {}.{!r}).main()\".format(app, arg_list))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import shutil\n",
+    "\n",
+    "# clean up any old data that might be around.\n",
+    "# We don't call this by default because it's destructive, \n",
+    "# and would waste a lot of time if it ran when nothing \n",
+    "# on the application side had changed.\n",
+    "def cleanCache():\n",
+    "    shutil.rmtree('data-unversioned/cache')\n",
+    "    os.mkdir('data-unversioned/cache')\n",
+    "\n",
+    "cleanCache()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "training_epochs = 20\n",
+    "experiment_epochs = 10\n",
+    "final_epochs = 50\n",
+    "\n",
+    "training_epochs = 2\n",
+    "experiment_epochs = 2\n",
+    "final_epochs = 5\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Chapter 11"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2019-08-18 13:26:12,563 INFO     pid:6276 nb:008:run Running: p2ch10.prepcache.LunaPrepCacheApp(None).main()\n",
+      "2019-08-18 13:26:12,564 INFO     pid:6276 p2ch10.prepcache:043:main Starting LunaPrepCacheApp, Namespace(batch_size=1024, num_workers=8)\n",
+      "2019-08-18 13:26:16,158 INFO     pid:6276 p2ch10.dsets:165:__init__ <p2ch10.dsets.LunaDataset object at 0x000001EB37328E48>: 551065 training samples\n",
+      "2019-08-18 13:26:16,159 WARNING  pid:6276 util.util:248:enumerateWithEstimate Stuffing cache ----/539, starting\n",
+      "2019-08-18 13:27:27,582 INFO     pid:6276 util.util:272:enumerateWithEstimate Stuffing cache   16/539, done at 2019-08-18 13:52:18, 0:25:16\n",
+      "2019-08-18 13:28:10,351 INFO     pid:6276 util.util:272:enumerateWithEstimate Stuffing cache   32/539, done at 2019-08-18 13:51:16, 0:24:14\n",
+      "2019-08-18 13:29:34,831 INFO     pid:6276 util.util:272:enumerateWithEstimate Stuffing cache   64/539, done at 2019-08-18 13:50:46, 0:23:44\n",
+      "2019-08-18 13:32:27,482 INFO     pid:6276 util.util:272:enumerateWithEstimate Stuffing cache  128/539, done at 2019-08-18 13:50:50, 0:23:48\n",
+      "2019-08-18 13:38:55,698 INFO     pid:6276 util.util:272:enumerateWithEstimate Stuffing cache  256/539, done at 2019-08-18 13:52:24, 0:25:22\n",
+      "2019-08-18 13:53:20,526 INFO     pid:6276 util.util:272:enumerateWithEstimate Stuffing cache  512/539, done at 2019-08-18 13:54:41, 0:27:39\n",
+      "2019-08-18 13:54:31,460 WARNING  pid:6276 util.util:285:enumerateWithEstimate Stuffing cache ----/539, done at 2019-08-18 13:54:31\n",
+      "2019-08-18 13:54:31,472 INFO     pid:6276 nb:013:run Finished: p2ch10.prepcache.LunaPrepCacheApp.None).main()\n"
+     ]
+    }
+   ],
+   "source": [
+    "run('p2ch11.prepcache.LunaPrepCacheApp')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2019-08-18 14:03:26,689 INFO     pid:6276 nb:008:run Running: p2ch10.training.LunaTrainingApp(['--epochs=1']).main()\n",
+      "2019-08-18 14:03:29,660 INFO     pid:6276 p2ch10.training:155:main Starting LunaTrainingApp, Namespace(batch_size=32, comment='none', epochs=1, num_workers=8, tb_prefix='p2ch10')\n",
+      "2019-08-18 14:03:30,029 INFO     pid:6276 p2ch10.dsets:165:__init__ <p2ch10.dsets.LunaDataset object at 0x000001EB3B67D160>: 495958 training samples\n",
+      "2019-08-18 14:03:30,079 INFO     pid:6276 p2ch10.dsets:165:__init__ <p2ch10.dsets.LunaDataset object at 0x000001EB3E44E320>: 55107 validation samples\n",
+      "2019-08-18 14:03:30,088 INFO     pid:6276 p2ch10.training:181:main Epoch 1 of 1, 15499/1723 batches of size 32*1\n",
+      "2019-08-18 14:03:30,093 WARNING  pid:6276 util.util:248:enumerateWithEstimate E1 Training ----/15499, starting\n",
+      "2019-08-18 14:03:55,834 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Training   16/15499, done at 2019-08-18 14:19:48, 0:15:53\n",
+      "2019-08-18 14:03:56,800 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Training   32/15499, done at 2019-08-18 14:19:37, 0:15:42\n",
+      "2019-08-18 14:03:58,747 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Training   64/15499, done at 2019-08-18 14:19:37, 0:15:42\n",
+      "2019-08-18 14:04:02,615 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Training  128/15499, done at 2019-08-18 14:19:34, 0:15:38\n",
+      "2019-08-18 14:04:10,374 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Training  256/15499, done at 2019-08-18 14:19:34, 0:15:39\n",
+      "2019-08-18 14:04:25,850 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Training  512/15499, done at 2019-08-18 14:19:33, 0:15:37\n",
+      "2019-08-18 14:04:56,935 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Training 1024/15499, done at 2019-08-18 14:19:34, 0:15:39\n",
+      "2019-08-18 14:05:59,204 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Training 2048/15499, done at 2019-08-18 14:19:35, 0:15:40\n",
+      "2019-08-18 14:08:03,762 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Training 4096/15499, done at 2019-08-18 14:19:36, 0:15:41\n",
+      "2019-08-18 14:12:14,139 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Training 8192/15499, done at 2019-08-18 14:19:39, 0:15:44\n",
+      "2019-08-18 14:19:41,875 WARNING  pid:6276 util.util:285:enumerateWithEstimate E1 Training ----/15499, done at 2019-08-18 14:19:41\n",
+      "2019-08-18 14:19:41,880 INFO     pid:6276 p2ch10.training:291:logMetrics E1 LunaTrainingApp\n",
+      "2019-08-18 14:19:41,900 INFO     pid:6276 p2ch10.training:344:logMetrics E1 trn      0.0192 loss,  99.8% correct, \n",
+      "2019-08-18 14:19:41,901 INFO     pid:6276 p2ch10.training:356:logMetrics E1 trn_ben  0.0026 loss, 100.0% correct (494735 of 494743)\n",
+      "2019-08-18 14:19:41,901 INFO     pid:6276 p2ch10.training:369:logMetrics E1 trn_mal  6.7756 loss,   0.0% correct (0 of 1215)\n",
+      "2019-08-18 14:19:41,975 WARNING  pid:6276 util.util:248:enumerateWithEstimate E1 Validation  ----/1723, starting\n",
+      "2019-08-18 14:19:52,708 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Validation    16/1723, done at 2019-08-18 14:21:15, 0:01:23\n",
+      "2019-08-18 14:19:53,160 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Validation    32/1723, done at 2019-08-18 14:20:53, 0:01:00\n",
+      "2019-08-18 14:19:54,134 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Validation    64/1723, done at 2019-08-18 14:20:48, 0:00:56\n",
+      "2019-08-18 14:19:56,367 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Validation   128/1723, done at 2019-08-18 14:20:50, 0:00:58\n",
+      "2019-08-18 14:20:00,518 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Validation   256/1723, done at 2019-08-18 14:20:49, 0:00:56\n",
+      "2019-08-18 14:20:08,278 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Validation   512/1723, done at 2019-08-18 14:20:46, 0:00:54\n",
+      "2019-08-18 14:20:24,934 INFO     pid:6276 util.util:272:enumerateWithEstimate E1 Validation  1024/1723, done at 2019-08-18 14:20:47, 0:00:55\n",
+      "2019-08-18 14:20:46,957 WARNING  pid:6276 util.util:285:enumerateWithEstimate E1 Validation  ----/1723, done at 2019-08-18 14:20:46\n",
+      "2019-08-18 14:20:46,959 INFO     pid:6276 p2ch10.training:291:logMetrics E1 LunaTrainingApp\n",
+      "2019-08-18 14:20:46,961 INFO     pid:6276 p2ch10.training:344:logMetrics E1 val      0.0176 loss,  99.8% correct, \n",
+      "2019-08-18 14:20:46,961 INFO     pid:6276 p2ch10.training:356:logMetrics E1 val_ben  0.0014 loss, 100.0% correct (54971 of 54971)\n",
+      "2019-08-18 14:20:46,962 INFO     pid:6276 p2ch10.training:369:logMetrics E1 val_mal  6.5590 loss,   0.0% correct (0 of 136)\n"
+     ]
+    }
+   ],
+   "source": [
+    "run('p2ch11.training.LunaTrainingApp', '--epochs=1')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch11.training.LunaTrainingApp', f'--epochs={experiment_epochs}')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Chapter 12"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch12.prepcache.LunaPrepCacheApp')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch12.training.LunaTrainingApp', '--epochs=1', 'unbalanced')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch12.training.LunaTrainingApp', f'--epochs={training_epochs}', '--balanced', 'balanced')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch12.training.LunaTrainingApp', f'--epochs={experiment_epochs}', '--balanced', '--augment-flip', 'flip')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch12.training.LunaTrainingApp', f'--epochs={experiment_epochs}', '--balanced', '--augment-offset', 'offset')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch12.training.LunaTrainingApp', f'--epochs={experiment_epochs}', '--balanced', '--augment-scale', 'scale')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch12.training.LunaTrainingApp', f'--epochs={experiment_epochs}', '--balanced', '--augment-rotate', 'rotate')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch12.training.LunaTrainingApp', f'--epochs={experiment_epochs}', '--balanced', '--augment-noise', 'noise')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch12.training.LunaTrainingApp', f'--epochs={training_epochs}', '--balanced', '--augmented', 'fully-augmented')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Chapter 13"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch13.prepcache.LunaPrepCacheApp')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch13.train_cls.LunaTrainingApp', f'--epochs={final_epochs}', '--balanced', '--augmented', 'final-cls')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch13.train_seg.LunaTrainingApp', f'--epochs={training_epochs}', '--balanced', '--augmented', 'final-seg')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Chapter 14"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "run('p2ch14.diagnose.LunaDiagnoseApp')"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
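
Note: run() resolves an app class from a dotted string via util.util.importstr (util/util.py changes by +24 -9 in this commit; its diff body isn't shown in this excerpt). A minimal sketch of the assumed behavior, not the actual implementation:

import importlib

def importstr(module_str, from_=None):
    # Import a module by dotted name and optionally pull an attribute off it.
    module = importlib.import_module(module_str)
    if from_ is None:
        return module
    try:
        return getattr(module, from_)
    except AttributeError:
        raise ImportError('{}.{}'.format(module_str, from_))

# run('p2ch11.training.LunaTrainingApp') therefore resolves to
# importstr('p2ch11.training', 'LunaTrainingApp') via app.rsplit('.', 1).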

+ 25 - 23
p2ch09/dsets.py

@@ -75,18 +75,18 @@ class Ct(object):
         mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
 
         ct_mhd = sitk.ReadImage(mhd_path)
-        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
 
         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
         # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
         # This gets rid of negative density stuff used to indicate out-of-FOV
-        ct_ary[ct_ary < -1000] = -1000
+        ct_a[ct_a < -1000] = -1000
 
         # This nukes any weird hotspots and clamps bone down
-        ct_ary[ct_ary > 1000] = 1000
+        ct_a[ct_a > 1000] = 1000
 
         self.series_uid = series_uid
-        self.ary = ct_ary
+        self.hu_a = ct_a
 
         self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
         self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
@@ -100,23 +100,23 @@ class Ct(object):
             start_ndx = int(round(center_val - width_irc[axis]/2))
             end_ndx = int(start_ndx + width_irc[axis])
 
-            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
 
             if start_ndx < 0:
                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                 start_ndx = 0
                 end_ndx = int(width_irc[axis])
 
-            if end_ndx > self.ary.shape[axis]:
+            if end_ndx > self.hu_a.shape[axis]:
                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                end_ndx = self.ary.shape[axis]
-                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
+                end_ndx = self.hu_a.shape[axis]
+                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
 
             slice_list.append(slice(start_ndx, end_ndx))
 
-        ct_chunk = self.ary[tuple(slice_list)]
+        ct_chunk = self.hu_a[tuple(slice_list)]
 
         return ct_chunk, center_irc
 
@@ -133,8 +133,8 @@ def getCtRawNodule(series_uid, center_xyz, width_irc):
 
 class LunaDataset(Dataset):
     def __init__(self,
-                 test_stride=0,
-                 isTestSet_bool=None,
+                 val_stride=0,
+                 isValSet_bool=None,
                  series_uid=None,
             ):
         self.noduleInfo_list = copy.copy(getNoduleInfoList())
@@ -142,16 +142,16 @@ class LunaDataset(Dataset):
         if series_uid:
             self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]
 
-        if test_stride > 1:
-            if isTestSet_bool:
-                self.noduleInfo_list = self.noduleInfo_list[::test_stride]
+        if val_stride > 1:
+            if isValSet_bool:
+                self.noduleInfo_list = self.noduleInfo_list[::val_stride]
             else:
-                del self.noduleInfo_list[::test_stride]
+                del self.noduleInfo_list[::val_stride]
 
         log.info("{!r}: {} {} samples".format(
             self,
             len(self.noduleInfo_list),
-            "testing" if isTestSet_bool else "training",
+            "validation" if isValSet_bool else "training",
         ))
 
     def __len__(self):
@@ -161,19 +161,21 @@ class LunaDataset(Dataset):
         nodule_tup = self.noduleInfo_list[ndx]
         width_irc = (24, 48, 48)
 
-        nodule_ary, center_irc = getCtRawNodule(
+        nodule_a, center_irc = getCtRawNodule(
             nodule_tup.series_uid,
             nodule_tup.center_xyz,
             width_irc,
         )
-        nodule_tensor = torch.from_numpy(nodule_ary).to(torch.float32)
-        nodule_tensor = nodule_tensor.unsqueeze(0)
 
-        cls_tensor = torch.tensor([
+        nodule_t = torch.from_numpy(nodule_a)
+        nodule_t = nodule_t.to(torch.float32)
+        nodule_t = nodule_t.unsqueeze(0)
+
+        cls_t = torch.tensor([
                 not nodule_tup.isMalignant_bool,
                 nodule_tup.isMalignant_bool
             ],
             dtype=torch.long,
         )
 
-        return nodule_tensor, cls_tensor, nodule_tup.series_uid, center_irc
+        return nodule_t, cls_t, nodule_tup.series_uid, center_irc
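
Note: the two in-place assignments in Ct.__init__ (now on ct_a) clamp the scan to the [-1000, 1000] HU window, air up to roughly bone. They are equivalent to a single numpy clip:

import numpy as np

ct_a = np.array([-3000.0, -1000.0, 0.0, 400.0, 2000.0], dtype=np.float32)
ct_a.clip(-1000, 1000, ct_a)   # in-place; same effect as the two masked assignments
assert ct_a.tolist() == [-1000.0, -1000.0, 0.0, 400.0, 1000.0]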

+ 10 - 10
p2ch09/vis.py

@@ -31,8 +31,8 @@ def showNodule(series_uid, batch_ndx=None):
             batch_ndx = 0
 
     ct = Ct(series_uid)
-    ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
-    ct_ary = ct_tensor[0].numpy()
+    ct_t, malignant_t, series_uid, center_irc = ds[batch_ndx]
+    ct_a = ct_t[0].numpy()
 
     fig = plt.figure(figsize=(15, 25))
 
@@ -44,35 +44,35 @@ def showNodule(series_uid, batch_ndx=None):
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
     subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
+    plt.imshow(ct.hu_a[int(center_irc.index)], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
     subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
+    plt.imshow(ct.hu_a[:,int(center_irc.row)], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
     subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+    plt.imshow(ct.hu_a[:,:,int(center_irc.col)], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
     subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct_ary[ct_ary.shape[0]//2], clim=clim, cmap='gray')
+    plt.imshow(ct_a[ct_a.shape[0]//2], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
     subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct_ary[:,ct_ary.shape[1]//2], clim=clim, cmap='gray')
+    plt.imshow(ct_a[:,ct_a.shape[1]//2], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
     subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct_ary[:,:,ct_ary.shape[2]//2], clim=clim, cmap='gray')
+    plt.imshow(ct_a[:,:,ct_a.shape[2]//2], clim=clim, cmap='gray')
 
     for row, index_list in enumerate(group_list):
         for col, index in enumerate(index_list):
             subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
             subplot.set_title('slice {}'.format(index))
-            plt.imshow(ct_ary[index], clim=clim, cmap='gray')
+            plt.imshow(ct_a[index], clim=clim, cmap='gray')
 
 
-    print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list)
+    print(series_uid, batch_ndx, bool(malignant_t[0]), malignant_list)
 
 

+ 100 - 12
p2ch09_explore_data.ipynb

@@ -16,7 +16,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from p2ch09.dsets import getNoduleInfoList, getCt\n",
+    "from p2ch09.dsets import getNoduleInfoList, getCt, LunaDataset\n",
     "noduleInfo_list = getNoduleInfoList(requireDataOnDisk_bool=False)\n",
     "malignantInfo_list = [x for x in noduleInfo_list if x[0]]\n",
     "diameter_list = [x[1] for x in malignantInfo_list]"
@@ -143,7 +143,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "2019-06-09 08:43:27,092 INFO     pid:21956 p2ch09.dsets:202:__init__ <p2ch09.dsets.LunaDataset object at 0x000001E9BE629A90>: 551065 training samples\n"
+      "2019-08-09 20:17:20,741 INFO     pid:19236 p2ch09.dsets:195:__init__ <p2ch09.dsets.LunaDataset object at 0x000001C761FB19E8>: 551065 training samples\n"
      ]
     }
    ],
@@ -161,7 +161,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "2019-06-09 08:43:27,152 INFO     pid:21956 p2ch09.dsets:202:__init__ <p2ch09.dsets.LunaDataset object at 0x000001E9BE858860>: 602 training samples\n"
+      "2019-08-09 20:17:20,851 INFO     pid:19236 p2ch09.dsets:195:__init__ <p2ch09.dsets.LunaDataset object at 0x000001C7630C7E10>: 602 training samples\n"
      ]
     },
     {
@@ -198,7 +198,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "2019-06-09 08:43:28,714 INFO     pid:21956 p2ch09.dsets:202:__init__ <p2ch09.dsets.LunaDataset object at 0x000001E9C39119B0>: 605 training samples\n"
+      "2019-08-09 20:17:22,701 INFO     pid:19236 p2ch09.dsets:195:__init__ <p2ch09.dsets.LunaDataset object at 0x000001C76585A9B0>: 605 training samples\n"
      ]
     },
     {
@@ -230,6 +230,84 @@
    "cell_type": "code",
    "execution_count": 10,
    "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2019-08-09 20:17:25,400 INFO     pid:19236 p2ch09.dsets:195:__init__ <p2ch09.dsets.LunaDataset object at 0x000001C76BFEFC88>: 551065 training samples\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([[[[-899., -903., -825.,  ..., -901., -898., -893.],\n",
+       "           [-892., -883., -794.,  ..., -894., -888., -890.],\n",
+       "           [-885., -854., -720.,  ..., -855., -844., -881.],\n",
+       "           ...,\n",
+       "           [-754., -674., -747.,  ..., -147.,  -97.,  -78.],\n",
+       "           [-663., -642., -771.,  ...,  -96.,  -77.,  -59.],\n",
+       "           [-627., -619., -771.,  ...,  -46.,  -64.,  -84.]],\n",
+       " \n",
+       "          [[-861., -836., -796.,  ..., -861., -839., -828.],\n",
+       "           [-902., -869., -855.,  ..., -862., -851., -832.],\n",
+       "           [-882., -845., -846.,  ..., -877., -864., -865.],\n",
+       "           ...,\n",
+       "           [-891., -893., -885.,  ..., -211., -113.,  -96.],\n",
+       "           [-898., -913., -902.,  ...,  -85.,  -82.,  -81.],\n",
+       "           [-834., -861., -844.,  ..., -107., -107.,  -80.]],\n",
+       " \n",
+       "          [[-770., -847., -893.,  ..., -819., -824., -845.],\n",
+       "           [-843., -876., -902.,  ..., -801., -794., -804.],\n",
+       "           [-880., -861., -853.,  ..., -840., -845., -837.],\n",
+       "           ...,\n",
+       "           [-842., -872., -879.,  ..., -242.,  -95.,  -41.],\n",
+       "           [-916., -929., -941.,  ..., -140., -106., -118.],\n",
+       "           [-892., -863., -860.,  ...,  -97., -100.,  -85.]],\n",
+       " \n",
+       "          ...,\n",
+       " \n",
+       "          [[-891., -902., -905.,  ..., -728., -771., -792.],\n",
+       "           [-905., -929., -909.,  ..., -780., -798., -756.],\n",
+       "           [-928., -940., -942.,  ..., -757., -749., -683.],\n",
+       "           ...,\n",
+       "           [-505., -450., -415.,  ...,   21.,   34.,   58.],\n",
+       "           [-320., -246., -219.,  ...,  102.,  144.,  122.],\n",
+       "           [-217., -192., -167.,  ...,   23.,   19.,   -2.]],\n",
+       " \n",
+       "          [[-903., -924., -945.,  ..., -734., -640., -478.],\n",
+       "           [-892., -930., -931.,  ..., -649., -544., -381.],\n",
+       "           [-786., -878., -917.,  ..., -594., -443., -239.],\n",
+       "           ...,\n",
+       "           [-134., -125., -125.,  ...,   39.,   30.,   48.],\n",
+       "           [-143., -132., -115.,  ...,   53.,   50.,   69.],\n",
+       "           [-146., -143., -135.,  ...,   35.,   63.,   76.]],\n",
+       " \n",
+       "          [[-897., -932., -950.,  ..., -565., -307., -135.],\n",
+       "           [-912., -913., -925.,  ..., -395., -171.,  -78.],\n",
+       "           [-939., -931., -950.,  ..., -219.,  -81.,  -37.],\n",
+       "           ...,\n",
+       "           [ -98.,  -75.,  -72.,  ...,   44.,   25.,   32.],\n",
+       "           [ -82.,  -81.,  -51.,  ...,   62.,   52.,   71.],\n",
+       "           [ -92.,  -63.,    4.,  ...,   63.,   70.,   52.]]]]),\n",
+       " tensor([0, 1]),\n",
+       " '1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886',\n",
+       " IrcTuple(index=91.27662336, row=359.6500497243522, col=341.1074155674705))"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "LunaDataset()[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
    "outputs": [
     {
      "name": "stderr",
@@ -242,7 +320,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "d50336a41d1c4d6fa11790858ed19075",
+       "model_id": "afaac78e7a024c26a2b5cb4013f450ec",
        "version_major": 2,
        "version_minor": 0
       },
@@ -269,13 +347,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "C:\\Users\\elis\\Miniconda3\\envs\\book\\lib\\site-packages\\ipyvolume\\widgets.py:179: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n",
+      "  data_view = self.data_original[view]\n",
+      "C:\\Users\\elis\\Miniconda3\\envs\\book\\lib\\site-packages\\ipyvolume\\utils.py:204: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n",
+      "  data = (data[slices1] + data[slices2])/2\n"
+     ]
+    },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "5fe29bf5359943f0bce86401ee46d29c",
+       "model_id": "ce17ca22f25c4f08a0511b80d0629ddf",
        "version_major": 2,
        "version_minor": 0
       },
@@ -289,12 +377,12 @@
    ],
    "source": [
     "ct = getCt(series_uid)\n",
-    "ipv.quickvolshow(ct.ary, level=[0.25, 0.5, 0.9], opacity=0.1, level_width=0.1, data_min=-1000, data_max=1000)"
+    "ipv.quickvolshow(ct.hu_a, level=[0.25, 0.5, 0.9], opacity=0.1, level_width=0.1, data_min=-1000, data_max=1000)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
     {
@@ -304,7 +392,7 @@
      "traceback": [
       "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
       "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
-      "\u001b[1;32m<ipython-input-12-2630724d96f5>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mp2ch10\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdsets\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mgetCt\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      2\u001b[0m \u001b[0mct\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgetCt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mseries_uid\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mair_mask\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlung_mask\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdense_mask\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdenoise_mask\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtissue_mask\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbody_mask\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mct\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbuild3dLungMask\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[1;32m<ipython-input-13-2630724d96f5>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mp2ch10\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdsets\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mgetCt\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      2\u001b[0m \u001b[0mct\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mgetCt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mseries_uid\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mair_mask\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlung_mask\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdense_mask\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdenoise_mask\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtissue_mask\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbody_mask\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mct\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbuild3dLungMask\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
       "\u001b[1;31mAttributeError\u001b[0m: 'Ct' object has no attribute 'build3dLungMask'"
      ]
     }
@@ -321,8 +409,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "bones = ct.ary * (ct.ary > 1.5)\n",
-    "lungs = ct.ary * air_mask\n",
+    "bones = ct.hu_a * (ct.hu_a > 1.5)\n",
+    "lungs = ct.hu_a * air_mask\n",
     "ipv.figure()\n",
     "ipv.pylab.volshow(bones + lungs, level=[0.17, 0.17, 0.23], data_min=0.1, data_max=0.9)\n",
     "ipv.show()"

+ 32 - 40
p2ch10/dsets.py

@@ -3,13 +3,12 @@ import csv
 import functools
 import glob
 import os
-import random
 
 from collections import namedtuple
 
 import SimpleITK as sitk
-
 import numpy as np
+
 import torch
 import torch.cuda
 from torch.utils.data import Dataset
@@ -75,18 +74,18 @@ class Ct(object):
         mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
 
         ct_mhd = sitk.ReadImage(mhd_path)
-        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
 
         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
         # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
         # This gets rid of negative density stuff used to indicate out-of-FOV
-        ct_ary[ct_ary < -1000] = -1000
+        ct_a[ct_a < -1000] = -1000
 
         # This nukes any weird hotspots and clamps bone down
-        ct_ary[ct_ary > 1000] = 1000
+        ct_a[ct_a > 1000] = 1000
 
         self.series_uid = series_uid
-        self.ary = ct_ary
+        self.hu_a = ct_a
 
         self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
         self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
@@ -100,23 +99,23 @@ class Ct(object):
             start_ndx = int(round(center_val - width_irc[axis]/2))
             end_ndx = int(start_ndx + width_irc[axis])
 
-            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
 
             if start_ndx < 0:
                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                 start_ndx = 0
                 end_ndx = int(width_irc[axis])
 
-            if end_ndx > self.ary.shape[axis]:
+            if end_ndx > self.hu_a.shape[axis]:
                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                end_ndx = self.ary.shape[axis]
-                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
+                end_ndx = self.hu_a.shape[axis]
+                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
 
             slice_list.append(slice(start_ndx, end_ndx))
 
-        ct_chunk = self.ary[tuple(slice_list)]
+        ct_chunk = self.hu_a[tuple(slice_list)]
 
         return ct_chunk, center_irc
 
@@ -131,38 +130,29 @@ def getCtRawNodule(series_uid, center_xyz, width_irc):
     ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
     return ct_chunk, center_irc
 
-
 class LunaDataset(Dataset):
     def __init__(self,
-                 test_stride=0,
-                 isTestSet_bool=None,
+                 val_stride=0,
+                 isValSet_bool=None,
                  series_uid=None,
-                 sortby_str='random',
             ):
         self.noduleInfo_list = copy.copy(getNoduleInfoList())
 
         if series_uid:
-            self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]
-
-        if test_stride > 1:
-            if isTestSet_bool:
-                self.noduleInfo_list = self.noduleInfo_list[::test_stride]
-            else:
-                del self.noduleInfo_list[::test_stride]
-
-        if sortby_str == 'random':
-            random.shuffle(self.noduleInfo_list)
-        elif sortby_str == 'series_uid':
-            self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sorting by series_uid, center_xyz)
-        elif sortby_str == 'malignancy_size':
-            pass
-        else:
-            raise Exception("Unknown sort: " + repr(sortby_str))
+            self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid == series_uid]
+
+        if isValSet_bool:
+            assert val_stride > 0, val_stride
+            self.noduleInfo_list = self.noduleInfo_list[::val_stride]
+            assert self.noduleInfo_list
+        elif val_stride > 0:
+            del self.noduleInfo_list[::val_stride]
+            assert self.noduleInfo_list
 
         log.info("{!r}: {} {} samples".format(
             self,
             len(self.noduleInfo_list),
-            "testing" if isTestSet_bool else "training",
+            "validation" if isValSet_bool else "training",
         ))
 
     def __len__(self):
@@ -170,21 +160,23 @@ class LunaDataset(Dataset):
 
     def __getitem__(self, ndx):
         nodule_tup = self.noduleInfo_list[ndx]
-        width_irc = (24, 48, 48)
+        width_irc = (32, 48, 48)
 
-        nodule_ary, center_irc = getCtRawNodule(
+        nodule_a, center_irc = getCtRawNodule(
             nodule_tup.series_uid,
             nodule_tup.center_xyz,
             width_irc,
         )
-        nodule_tensor = torch.from_numpy(nodule_ary).to(torch.float32)
-        nodule_tensor = nodule_tensor.unsqueeze(0)
 
-        cls_tensor = torch.tensor([
+        nodule_t = torch.from_numpy(nodule_a)
+        nodule_t = nodule_t.to(torch.float32)
+        nodule_t = nodule_t.unsqueeze(0)
+
+        malignant_t = torch.tensor([
                 not nodule_tup.isMalignant_bool,
                 nodule_tup.isMalignant_bool
             ],
             dtype=torch.long,
         )
 
-        return nodule_tensor, cls_tensor, nodule_tup.series_uid, center_irc
+        return nodule_t, malignant_t, nodule_tup.series_uid, torch.tensor(center_irc)
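
Note: the new val_stride/isValSet_bool logic keeps every val_stride-th nodule for validation and deletes those same entries from the training list, so the two splits are disjoint by construction. A toy illustration:

noduleInfo_list = list(range(10))
val_stride = 5

val_list = noduleInfo_list[::val_stride]   # isValSet_bool=True branch
trn_list = list(noduleInfo_list)
del trn_list[::val_stride]                 # training branch

assert val_list == [0, 5]
assert trn_list == [1, 2, 3, 4, 6, 7, 8, 9]
assert not set(val_list) & set(trn_list)   # disjoint splits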

+ 1 - 1
p2ch10/model.py

@@ -15,7 +15,7 @@ class LunaModel(nn.Module):
     def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
         super().__init__()
 
-        self.input_batchnorm = nn.BatchNorm2d(1)
+        self.input_batchnorm = nn.BatchNorm3d(1)
 
         layer_list = []
         for layer_ndx in range(layer_count):
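
Note: this small-looking change matters: the model consumes 5-D batches of 3-D CT chunks (N, C, depth, height, width), and nn.BatchNorm2d rejects anything but 4-D input. Quick check of the shape expectations:

import torch
import torch.nn as nn

x = torch.randn(2, 1, 32, 48, 48)      # (N, C, D, H, W) batch of CT chunks
print(nn.BatchNorm3d(1)(x).shape)      # torch.Size([2, 1, 32, 48, 48])
try:
    nn.BatchNorm2d(1)(x)
except ValueError as err:
    print(err)                         # expected 4D input (got 5D input)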

+ 66 - 63
p2ch10/training.py

@@ -5,7 +5,7 @@ import sys
 
 import numpy as np
 
-from tensorboardX import SummaryWriter
+from torch.utils.tensorboard import SummaryWriter
 
 import torch
 import torch.nn as nn
@@ -22,7 +22,7 @@ log = logging.getLogger(__name__)
 log.setLevel(logging.INFO)
 # log.setLevel(logging.DEBUG)
 
-# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
 METRICS_LABEL_NDX=0
 METRICS_PRED_NDX=1
 METRICS_LOSS_NDX=2
@@ -58,14 +58,14 @@ class LunaTrainingApp(object):
         parser.add_argument('comment',
             help="Comment suffix for Tensorboard run.",
             nargs='?',
-            default='none',
+            default='dwlpt',
         )
 
         self.cli_args = parser.parse_args(sys_argv)
         self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
 
         self.trn_writer = None
-        self.tst_writer = None
+        self.val_writer = None
         self.totalTrainingSamples_count = 0
 
         self.use_cuda = torch.cuda.is_available()
@@ -77,6 +77,7 @@ class LunaTrainingApp(object):
     def initModel(self):
         model = LunaModel()
         if self.use_cuda:
+            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
             if torch.cuda.device_count() > 1:
                 model = nn.DataParallel(model)
             model = model.to(self.device)
@@ -88,8 +89,8 @@ class LunaTrainingApp(object):
 
     def initTrainDl(self):
         train_ds = LunaDataset(
-            test_stride=10,
-            isTestSet_bool=False,
+            val_stride=10,
+            isValSet_bool=False,
         )
 
         train_dl = DataLoader(
@@ -101,35 +102,34 @@ class LunaTrainingApp(object):
 
         return train_dl
 
-    def initTestDl(self):
-        test_ds = LunaDataset(
-            test_stride=10,
-            isTestSet_bool=True,
+    def initValDl(self):
+        val_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=True,
         )
 
-        test_dl = DataLoader(
-            test_ds,
+        val_dl = DataLoader(
+            val_ds,
             batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
             num_workers=self.cli_args.num_workers,
             pin_memory=self.use_cuda,
         )
 
-        return test_dl
+        return val_dl
 
     def initTensorboardWriters(self):
         if self.trn_writer is None:
             log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
 
-            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_cls_' + self.cli_args.comment)
-            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_cls_' + self.cli_args.comment)
-# eng::tb_writer[]
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
+            self.val_writer = SummaryWriter(log_dir=log_dir + '-val_cls-' + self.cli_args.comment)
 
 
     def main(self):
         log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
 
         train_dl = self.initTrainDl()
-        test_dl = self.initTestDl()
+        val_dl = self.initValDl()
 
         self.initTensorboardWriters()
         # self.logModelMetrics(self.model)
@@ -142,25 +142,25 @@ class LunaTrainingApp(object):
                 epoch_ndx,
                 self.cli_args.epochs,
                 len(train_dl),
-                len(test_dl),
+                len(val_dl),
                 self.cli_args.batch_size,
                 (torch.cuda.device_count() if self.use_cuda else 1),
             ))
 
-            trnMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
-            self.logMetrics(epoch_ndx, 'trn', trnMetrics_tensor)
+            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
 
-            tstMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
-            self.logMetrics(epoch_ndx, 'tst', tstMetrics_tensor)
+            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
+            self.logMetrics(epoch_ndx, 'val', valMetrics_t)
 
         if hasattr(self, 'trn_writer'):
             self.trn_writer.close()
-            self.tst_writer.close()
+            self.val_writer.close()
 
 
     def doTraining(self, epoch_ndx, train_dl):
         self.model.train()
-        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        trnMetrics_g = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
         batch_iter = enumerateWithEstimate(
             train_dl,
             "E{} Training".format(epoch_ndx),
@@ -173,70 +173,73 @@ class LunaTrainingApp(object):
                 batch_ndx,
                 batch_tup,
                 train_dl.batch_size,
-                trainingMetrics_devtensor
+                trnMetrics_g
             )
 
             loss_var.backward()
             self.optimizer.step()
             del loss_var
 
-        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
+        self.totalTrainingSamples_count += trnMetrics_g.size(1)
 
-        return trainingMetrics_devtensor.to('cpu')
+        return trnMetrics_g.to('cpu')
 
 
-    def doTesting(self, epoch_ndx, test_dl):
+    def doValidation(self, epoch_ndx, val_dl):
         with torch.no_grad():
             self.model.eval()
-            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
+            valMetrics_g = torch.zeros(METRICS_SIZE, len(val_dl.dataset)).to(self.device)
             batch_iter = enumerateWithEstimate(
-                test_dl,
-                "E{} Testing ".format(epoch_ndx),
-                start_ndx=test_dl.num_workers,
+                val_dl,
+                "E{} Validation ".format(epoch_ndx),
+                start_ndx=val_dl.num_workers,
             )
             for batch_ndx, batch_tup in batch_iter:
-                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
+                self.computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)
 
-        return testingMetrics_devtensor.to('cpu')
+        return valMetrics_g.to('cpu')
 
 
 
-    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
-        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
+        input_t, label_t, _series_list, _center_list = batch_tup
 
-        input_devtensor = input_tensor.to(self.device, non_blocking=True)
-        label_devtensor = label_tensor.to(self.device, non_blocking=True)
+        input_g = input_t.to(self.device, non_blocking=True)
+        label_g = label_t.to(self.device, non_blocking=True)
 
-        logits_devtensor, probability_devtensor = self.model(input_devtensor)
+        logits_g, probability_g = self.model(input_g)
 
         loss_func = nn.CrossEntropyLoss(reduction='none')
-        loss_devtensor = loss_func(logits_devtensor, label_devtensor[:,1])
+        loss_g = loss_func(
+            logits_g,
+            label_g[:,1],
+        )
         start_ndx = batch_ndx * batch_size
-        end_ndx = start_ndx + label_tensor.size(0)
+        end_ndx = start_ndx + label_t.size(0)
 
-        metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_devtensor[:,1]
-        metrics_devtensor[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_devtensor[:,1]
-        metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor
+        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_g[:,1]
+        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_g[:,1]
+        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g
 
-        return loss_devtensor.mean()
+        return loss_g.mean()
 
 
     def logMetrics(
             self,
             epoch_ndx,
             mode_str,
-            metrics_tensor,
+            metrics_g,
     ):
         log.info("E{} {}".format(
             epoch_ndx,
             type(self).__name__,
         ))
 
-        metrics_ary = metrics_tensor.cpu().detach().numpy()
-#         assert np.isfinite(metrics_ary).all()
+#         metrics_a = metrics_t.cpu().detach().numpy()
+#         assert np.isfinite(metrics_a).all()
 
-        benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= 0.5
-        benPred_mask = metrics_ary[METRICS_PRED_NDX] <= 0.5
+        benLabel_mask = metrics_g[METRICS_LABEL_NDX] <= 0.5
+        benPred_mask = metrics_g[METRICS_PRED_NDX] <= 0.5
 
         malLabel_mask = ~benLabel_mask
         malPred_mask = ~benPred_mask
@@ -253,16 +256,16 @@ class LunaTrainingApp(object):
         # falsePos_count = benLabel_count - benCorrect_count
         # falseNeg_count = malLabel_count - malCorrect_count
 
-        # log.info(['min loss', metrics_ary[METRICS_LOSS_NDX, benLabel_mask].min(), metrics_ary[METRICS_LOSS_NDX, malLabel_mask].min()])
-        # log.info(['max loss', metrics_ary[METRICS_LOSS_NDX, benLabel_mask].max(), metrics_ary[METRICS_LOSS_NDX, malLabel_mask].max()])
+        # log.info(['min loss', metrics_a[METRICS_LOSS_NDX, benLabel_mask].min(), metrics_a[METRICS_LOSS_NDX, malLabel_mask].min()])
+        # log.info(['max loss', metrics_a[METRICS_LOSS_NDX, benLabel_mask].max(), metrics_a[METRICS_LOSS_NDX, malLabel_mask].max()])
 
 
         metrics_dict = {}
-        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
-        metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
-        metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
+        metrics_dict['loss/all'] = metrics_g[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/ben'] = metrics_g[METRICS_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = metrics_g[METRICS_LOSS_NDX, malLabel_mask].mean()
 
-        metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
+        metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_g.shape[1] * 100
         metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
         metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
 
@@ -308,27 +311,27 @@ class LunaTrainingApp(object):
 
         writer.add_pr_curve(
             'pr',
-            metrics_ary[METRICS_LABEL_NDX],
-            metrics_ary[METRICS_PRED_NDX],
+            metrics_g[METRICS_LABEL_NDX],
+            metrics_g[METRICS_PRED_NDX],
             self.totalTrainingSamples_count,
         )
 
         bins = [x/50.0 for x in range(51)]
 
-        benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
-        malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
+        benHist_mask = benLabel_mask & (metrics_g[METRICS_PRED_NDX] > 0.01)
+        malHist_mask = malLabel_mask & (metrics_g[METRICS_PRED_NDX] < 0.99)
 
         if benHist_mask.any():
             writer.add_histogram(
                 'is_ben',
-                metrics_ary[METRICS_PRED_NDX, benHist_mask],
+                metrics_g[METRICS_PRED_NDX, benHist_mask],
                 self.totalTrainingSamples_count,
                 bins=bins,
             )
         if malHist_mask.any():
             writer.add_histogram(
                 'is_mal',
-                metrics_ary[METRICS_PRED_NDX, malHist_mask],
+                metrics_g[METRICS_PRED_NDX, malHist_mask],
                 self.totalTrainingSamples_count,
                 bins=bins,
             )
@@ -357,7 +360,7 @@ class LunaTrainingApp(object):
     #                 writer.add_histogram(
     #                     name.rsplit('.', 1)[-1] + '/' + name,
     #                     param.data.cpu().numpy(),
-    #                     # metrics_ary[METRICS_PRED_NDX, benHist_mask],
+    #                     # metrics_a[METRICS_PRED_NDX, benHist_mask],
     #                     self.totalTrainingSamples_count,
     #                     # bins=bins,
     #                 )
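The hunk above keeps the TensorBoard calls while renaming the metrics array; in isolation, the same add_pr_curve/add_histogram pattern looks like this (a minimal sketch with synthetic labels and predictions; the run directory name is made up):

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir='runs/pr_demo')    # hypothetical run dir

    label_t = (torch.rand(1000) > 0.5).float()        # fake 0/1 ground truth
    pred_t = torch.rand(1000)                         # fake predicted probabilities

    writer.add_pr_curve('pr', label_t, pred_t, global_step=0)
    writer.add_histogram('is_mal', pred_t, global_step=0,
                         bins=[x / 50.0 for x in range(51)])
    writer.close()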

+ 46 - 33
p2ch10/vis.py

@@ -4,16 +4,16 @@ matplotlib.use('nbagg')
 import numpy as np
 import matplotlib.pyplot as plt
 
-from p2ch11_old.dsets import Ct, LunaDataset
+from p2ch10.dsets import Ct, LunaDataset
 
-clim=(0.0, 1.3)
+clim=(-1000.0, 300)
 
-def findMalignantSamples(start_ndx=0, limit=10):
+def findMalignantSamples(start_ndx=0, limit=100):
     ds = LunaDataset()
 
     malignantSample_list = []
-    for sample_tup in ds.sample_list:
-        if sample_tup[2]:
+    for sample_tup in ds.noduleInfo_list:
+        if sample_tup.isMalignant_bool:
             print(len(malignantSample_list), sample_tup)
             malignantSample_list.append(sample_tup)
 
@@ -24,7 +24,7 @@ def findMalignantSamples(start_ndx=0, limit=10):
 
 def showNodule(series_uid, batch_ndx=None, **kwargs):
     ds = LunaDataset(series_uid=series_uid, **kwargs)
-    malignant_list = [i for i, x in enumerate(ds.sample_list) if x[2]]
+    malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x.isMalignant_bool]
 
     if batch_ndx is None:
         if malignant_list:
@@ -34,53 +34,66 @@ def showNodule(series_uid, batch_ndx=None, **kwargs):
             batch_ndx = 0
 
     ct = Ct(series_uid)
-    # ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
-    malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor = ds[batch_ndx]
-    ct_ary = nodule_tensor[1].numpy()
+    ct_t, malignant_t, series_uid, center_irc = ds[batch_ndx]
+    ct_a = ct_t[0].numpy()
 
-
-    fig = plt.figure(figsize=(15, 25))
+    fig = plt.figure(figsize=(30, 50))
 
     group_list = [
-        #[0,1,2],
-        [3,4,5],
-        [6,7,8],
-        [9,10,11],
-        #[12,13,14],
-        #[15]
+        [9, 11, 13],
+        [15, 16, 17],
+        [19, 21, 23],
     ]
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
-    subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
+    subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[int(center_irc.index)], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
-    subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
+    subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[:,int(center_irc.row)], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
-    subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+    subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
-    subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct_ary[ct_ary.shape[0]//2], clim=clim, cmap='gray')
+    subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[ct_a.shape[0]//2], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
-    subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct_ary[:,ct_ary.shape[1]//2], clim=clim, cmap='gray')
+    subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[:,ct_a.shape[1]//2], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
-    subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct_ary[:,:,ct_ary.shape[2]//2], clim=clim, cmap='gray')
+    subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[:,:,ct_a.shape[2]//2], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     for row, index_list in enumerate(group_list):
         for col, index in enumerate(index_list):
             subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
-            subplot.set_title('slice {}'.format(index))
-            plt.imshow(ct_ary[index*2], clim=clim, cmap='gray')
+            subplot.set_title('slice {}'.format(index), fontsize=30)
+            for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+                label.set_fontsize(20)
+            plt.imshow(ct_a[index], clim=clim, cmap='gray')
+
 
+    print(series_uid, batch_ndx, bool(malignant_t[0]), malignant_list)
 
-    print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list, ct.vxSize_xyz)
 
-    return ct_ary
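The clim change here tracks dsets.py now exposing raw Hounsfield units as hu_a: with values clamped to [-1000, 1000], a display window of (-1000, 300) renders air as black and soft tissue near white while clipping dense bone. In miniature (the random array is a stand-in for a real CT slice):

    import numpy as np
    import matplotlib.pyplot as plt

    slice_a = np.random.uniform(-1000, 1000, size=(128, 128))  # fake slice in HU
    plt.imshow(slice_a, clim=(-1000.0, 300), cmap='gray')      # same display window
    plt.show()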

File diff suppressed because it is too large
+ 77 - 63
p2ch10_explore_data.ipynb


+ 7 - 7
p2ch11/1_final_metric_f1_score.ipynb

@@ -11,10 +11,10 @@
     "import numpy as np\n",
     "\n",
     "# Make precision and recall data.\n",
-    "range_ary = np.arange(0.01, 1, 0.01)\n",
-    "precision_ary, recall_ary = np.meshgrid(range_ary, range_ary)\n",
+    "range_a = np.arange(0.01, 1, 0.01)\n",
+    "precision_a, recall_a = np.meshgrid(range_a, range_a)\n",
     "\n",
-    "f1_score = np.sqrt(2 * precision_ary * recall_ary / (precision_ary + recall_ary))\n",
+    "f1_score = np.sqrt(2 * precision_a * recall_a / (precision_a + recall_a))\n",
     "\n",
     "def plotScore(title_str, other_score):\n",
     "    fig, subplts = plt.subplots(nrows=1, ncols=1, dpi=300, figsize=(7/2, 2.5))\n",
@@ -67,7 +67,7 @@
     }
    ],
    "source": [
-    "add_score = (precision_ary + recall_ary) / 2\n",
+    "add_score = (precision_a + recall_a) / 2\n",
     "plotScores(\"avg\", add_score)"
    ]
   },
@@ -88,7 +88,7 @@
     }
    ],
    "source": [
-    "min_score = np.min(np.array([precision_ary, recall_ary]), axis=0)\n",
+    "min_score = np.min(np.array([precision_a, recall_a]), axis=0)\n",
     "plotScores(\"min\", min_score)"
    ]
   },
@@ -109,7 +109,7 @@
     }
    ],
    "source": [
-    "mult_score = precision_ary * recall_ary\n",
+    "mult_score = precision_a * recall_a\n",
     "plotScores(\"mult\", mult_score)"
    ]
   },
@@ -130,7 +130,7 @@
     }
    ],
    "source": [
-    "sqrt_score = np.sqrt(precision_ary * recall_ary)\n",
+    "sqrt_score = np.sqrt(precision_a * recall_a)\n",
     "plotScores(\"sqrt\", sqrt_score)"
    ]
   }
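A worked point makes the comparison concrete: for a heavily imbalanced predictor with precision 0.01 and recall 1.0, the candidate scores above disagree sharply; the averaged score barely notices the imbalance, while F1 collapses toward zero:

    import numpy as np

    p, r = 0.01, 1.0
    print((p + r) / 2)          # avg:  0.505 -- far too forgiving
    print(min(p, r))            # min:  0.01
    print(p * r)                # mult: 0.01
    print(np.sqrt(p * r))       # sqrt: 0.1
    print(2 * p * r / (p + r))  # f1:   ~0.0198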

+ 39 - 163
p2ch11/dsets.py

@@ -2,22 +2,17 @@ import copy
 import csv
 import functools
 import glob
-import math
 import os
 import random
 
 from collections import namedtuple
 
 import SimpleITK as sitk
-
 import numpy as np
+
 import torch
 import torch.cuda
 from torch.utils.data import Dataset
-import torch.nn as nn
-import torch.nn.functional as F
-
-import numpy as np
 
 from util.disk import getCache
 from util.util import XyzTuple, xyz2irc
@@ -80,18 +75,18 @@ class Ct(object):
         mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
 
         ct_mhd = sitk.ReadImage(mhd_path)
-        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
 
         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
         # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
         # This gets rid of negative density stuff used to indicate out-of-FOV
-        ct_ary[ct_ary < -1000] = -1000
+        ct_a[ct_a < -1000] = -1000
 
         # This nukes any weird hotspots and clamps bone down
-        ct_ary[ct_ary > 1000] = 1000
+        ct_a[ct_a > 1000] = 1000
 
         self.series_uid = series_uid
-        self.ary = ct_ary
+        self.hu_a = ct_a
 
         self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
         self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
@@ -105,23 +100,23 @@ class Ct(object):
             start_ndx = int(round(center_val - width_irc[axis]/2))
             end_ndx = int(start_ndx + width_irc[axis])
 
-            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
 
             if start_ndx < 0:
                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                 start_ndx = 0
                 end_ndx = int(width_irc[axis])
 
-            if end_ndx > self.ary.shape[axis]:
+            if end_ndx > self.hu_a.shape[axis]:
                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                end_ndx = self.ary.shape[axis]
-                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
+                end_ndx = self.hu_a.shape[axis]
+                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
 
             slice_list.append(slice(start_ndx, end_ndx))
 
-        ct_chunk = self.ary[tuple(slice_list)]
+        ct_chunk = self.hu_a[tuple(slice_list)]
 
         return ct_chunk, center_irc
 
@@ -136,181 +131,62 @@ def getCtRawNodule(series_uid, center_xyz, width_irc):
     ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
     return ct_chunk, center_irc
 
-def getCtAugmentedNodule(
-        augmentation_dict,
-        series_uid, center_xyz, width_irc,
-        use_cache=True):
-    if use_cache:
-        ct_chunk, center_irc = getCtRawNodule(series_uid, center_xyz, width_irc)
-    else:
-        ct = getCt(series_uid)
-        ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
-
-    ct_tensor = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
-
-    transform_tensor = torch.eye(4).to(torch.float64)
-    # ... <1>
-
-    for i in range(3):
-        if 'flip' in augmentation_dict:
-            if random.random() > 0.5:
-                transform_tensor[i,i] *= -1
-
-        if 'offset' in augmentation_dict:
-            offset_float = augmentation_dict['offset']
-            random_float = (random.random() * 2 - 1)
-            transform_tensor[3,i] = offset_float * random_float
-
-        if 'scale' in augmentation_dict:
-            scale_float = augmentation_dict['scale']
-            random_float = (random.random() * 2 - 1)
-            transform_tensor[i,i] *= 1.0 + scale_float * random_float
-
-
-    if 'rotate' in augmentation_dict:
-        angle_rad = random.random() * math.pi * 2
-        s = math.sin(angle_rad)
-        c = math.cos(angle_rad)
-
-        rotation_tensor = torch.tensor([
-            [c, -s, 0, 0],
-            [s, c, 0, 0],
-            [0, 0, 1, 0],
-            [0, 0, 0, 1],
-        ], dtype=torch.float64)
-
-        transform_tensor @= rotation_tensor
-
-    affine_tensor = F.affine_grid(
-            transform_tensor[:3].unsqueeze(0).to(torch.float32),
-            ct_tensor.size(),
-        )
-
-    augmented_chunk = F.grid_sample(
-            ct_tensor,
-            affine_tensor,
-            padding_mode='border'
-        ).to('cpu')
-
-    if 'noise' in augmentation_dict:
-        noise_tensor = torch.randn_like(augmented_chunk)
-        noise_tensor *= augmentation_dict['noise']
-
-        augmented_chunk += noise_tensor
-
-    return augmented_chunk[0], center_irc
-
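getCtAugmentedNodule disappears from this chapter (augmentation returns in a later one), but the core trick deserves a standalone sketch: compose a 4x4 affine transform, then let affine_grid and grid_sample resample the chunk. A minimal flip-only version on a random chunk, not the full function above:

    import torch
    import torch.nn.functional as F

    chunk_t = torch.randn(1, 1, 24, 48, 48)    # fake (N, C, index, row, col) chunk

    transform_t = torch.eye(4)
    for i in range(3):
        if torch.rand(1).item() > 0.5:         # mirror each axis half the time
            transform_t[i, i] *= -1

    affine_t = F.affine_grid(
        transform_t[:3].unsqueeze(0),          # theta shaped (1, 3, 4)
        chunk_t.size(),
        align_corners=False,
    )
    augmented_t = F.grid_sample(chunk_t, affine_t,
                                padding_mode='border', align_corners=False)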
 
 class LunaDataset(Dataset):
     def __init__(self,
-                 test_stride=0,
-                 isTestSet_bool=None,
+                 val_stride=0,
+                 isValSet_bool=None,
                  series_uid=None,
                  sortby_str='random',
-                 ratio_int=0,
-                 augmentation_dict=None,
-                 noduleInfo_list=None,
             ):
-        self.ratio_int = ratio_int
-        self.augmentation_dict = augmentation_dict
-
-        if noduleInfo_list:
-            self.noduleInfo_list = copy.copy(noduleInfo_list)
-            self.use_cache = False
-        else:
-            self.noduleInfo_list = copy.copy(getNoduleInfoList())
-            self.use_cache = True
+        self.noduleInfo_list = copy.copy(getNoduleInfoList())
 
         if series_uid:
             self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid == series_uid]
 
-        if test_stride > 1:
-            if isTestSet_bool:
-                self.noduleInfo_list = self.noduleInfo_list[::test_stride]
-            else:
-                del self.noduleInfo_list[::test_stride]
+        if isValSet_bool:
+            assert val_stride > 0, val_stride
+            self.noduleInfo_list = self.noduleInfo_list[::val_stride]
+            assert self.noduleInfo_list
+        elif val_stride > 0:
+            del self.noduleInfo_list[::val_stride]
+            assert self.noduleInfo_list
 
         if sortby_str == 'random':
             random.shuffle(self.noduleInfo_list)
         elif sortby_str == 'series_uid':
-            self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sorting by series_uid, center_xyz)
+            self.noduleInfo_list.sort(key=lambda x: (x.series_uid, x.center_xyz))
         elif sortby_str == 'malignancy_size':
             pass
         else:
             raise Exception("Unknown sort: " + repr(sortby_str))
 
-        self.benign_list = [nt for nt in self.noduleInfo_list if not nt.isMalignant_bool]
-        self.malignant_list = [nt for nt in self.noduleInfo_list if nt.isMalignant_bool]
-
-        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
+        log.info("{!r}: {} {} samples".format(
             self,
             len(self.noduleInfo_list),
-            "testing" if isTestSet_bool else "training",
-            len(self.benign_list),
-            len(self.malignant_list),
-            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
+            "validation" if isValSet_bool else "training",
         ))
 
-    def shuffleSamples(self):
-        if self.ratio_int:
-            random.shuffle(self.benign_list)
-            random.shuffle(self.malignant_list)
-
     def __len__(self):
-        if self.ratio_int:
-            return 200000
-        else:
-            return len(self.noduleInfo_list)
+        return len(self.noduleInfo_list)
 
     def __getitem__(self, ndx):
-        if self.ratio_int:
-            malignant_ndx = ndx // (self.ratio_int + 1)
-
-            if ndx % (self.ratio_int + 1):
-                benign_ndx = ndx - 1 - malignant_ndx
-                benign_ndx %= len(self.benign_list)
-                nodule_tup = self.benign_list[benign_ndx]
-            else:
-                malignant_ndx %= len(self.malignant_list)
-                nodule_tup = self.malignant_list[malignant_ndx]
-        else:
-            nodule_tup = self.noduleInfo_list[ndx]
-
-        width_irc = (24, 48, 48)
-
-        if self.augmentation_dict:
-            nodule_t, center_irc = getCtAugmentedNodule(
-                self.augmentation_dict,
-                nodule_tup.series_uid,
-                nodule_tup.center_xyz,
-                width_irc,
-                self.use_cache,
-            )
-        elif self.use_cache:
-            nodule_ary, center_irc = getCtRawNodule(
-                nodule_tup.series_uid,
-                nodule_tup.center_xyz,
-                width_irc,
-            )
-            nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
-            nodule_t = nodule_t.unsqueeze(0)
-        else:
-            ct = getCt(nodule_tup.series_uid)
-            nodule_ary, center_irc = ct.getRawNodule(
-                nodule_tup.center_xyz,
-                width_irc,
-            )
-            nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
-            nodule_t = nodule_t.unsqueeze(0)
-
-        malignant_tensor = torch.tensor([
+        nodule_tup = self.noduleInfo_list[ndx]
+        width_irc = (32, 48, 48)
+
+        nodule_a, center_irc = getCtRawNodule(
+            nodule_tup.series_uid,
+            nodule_tup.center_xyz,
+            width_irc,
+        )
+        nodule_t = torch.from_numpy(nodule_a).to(torch.float32)
+        nodule_t = nodule_t.unsqueeze(0)
+
+        malignant_t = torch.tensor([
                 not nodule_tup.isMalignant_bool,
                 nodule_tup.isMalignant_bool
             ],
             dtype=torch.long,
         )
 
-        return nodule_t, malignant_tensor, nodule_tup.series_uid, center_irc
-
-
-
+        return nodule_t, malignant_t, nodule_tup.series_uid, torch.tensor(center_irc)
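The val_stride/isValSet_bool split above, in miniature: every val_stride-th nodule becomes validation and is deleted from training, so the two sets are disjoint by construction (a sketch with a hypothetical list of twenty nodules):

    nodules = list(range(20))
    val_stride = 10

    val_list = nodules[::val_stride]            # every tenth nodule: [0, 10]
    trn_list = list(nodules)
    del trn_list[::val_stride]                  # the remaining eighteen

    assert val_list and trn_list                # same guards as in __init__
    assert not set(val_list) & set(trn_list)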

+ 45 - 28
p2ch11/model.py

@@ -1,6 +1,6 @@
 import math
 
-import torch.nn as nn
+from torch import nn as nn
 
 from util.logconf import logging
 
@@ -11,53 +11,70 @@ log.setLevel(logging.DEBUG)
 
 
 class LunaModel(nn.Module):
-    def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
+    def __init__(self, in_channels=1, conv_channels=8):
         super().__init__()
 
-        self.input_batchnorm = nn.BatchNorm2d(1)
+        self.tail_batchnorm = nn.BatchNorm3d(1)
 
-        layer_list = []
-        for layer_ndx in range(layer_count):
-            layer_list += [
-                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
-                nn.ReLU(inplace=True),
-                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
-                nn.ReLU(inplace=True),
-                nn.MaxPool3d(2, 2),
-            ]
+        self.block1 = LunaBlock(in_channels, conv_channels)
+        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
+        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
+        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)
 
-            in_channels = conv_channels
-            conv_channels *= 2
-
-        self.convAndPool_seq = nn.Sequential(*layer_list)
-        self.fullyConnected_layer = nn.Linear(576, 2)
-        self.final = nn.Softmax(dim=1)
+        self.head_linear = nn.Linear(1152, 2)
+        self.head_softmax = nn.Softmax(dim=1)
 
         self._init_weights()
 
+    # see also https://github.com/pytorch/pytorch/issues/18182
     def _init_weights(self):
-        # see also https://github.com/pytorch/pytorch/issues/18182
         for m in self.modules():
             if type(m) in {
-                nn.Conv2d,
+                nn.Linear,
                 nn.Conv3d,
+                nn.Conv2d,
                 nn.ConvTranspose2d,
                 nn.ConvTranspose3d,
-                nn.Linear,
             }:
-                # log.debug(m)
-                # nn.init.kaiming_normal_(m.weight.data, mode='fan_out', a=0)
                 nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
                 if m.bias is not None:
                     fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                     bound = 1 / math.sqrt(fan_out)
                     nn.init.normal_(m.bias, -bound, bound)
 
+
     def forward(self, input_batch):
-        bn_output = self.input_batchnorm(input_batch)
-        conv_output = self.convAndPool_seq(bn_output)
-        conv_flat = conv_output.view(conv_output.size(0), -1)
-        classifier_output = self.fullyConnected_layer(conv_flat)
+        bn_output = self.tail_batchnorm(input_batch)
+
+        block_out = self.block1(bn_output)
+        block_out = self.block2(block_out)
+        block_out = self.block3(block_out)
+        block_out = self.block4(block_out)
+
+        conv_flat = block_out.view(
+            block_out.size(0),
+            -1,
+        )
+        linear_output = self.head_linear(conv_flat)
 
-        return classifier_output, self.final(classifier_output)
+        return linear_output, self.head_softmax(linear_output)
+
+
+class LunaBlock(nn.Module):
+    def __init__(self, in_channels, conv_channels):
+        super().__init__()
+
+        self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.maxpool = nn.MaxPool3d(2, 2)
+
+    def forward(self, input_batch):
+        block_out = self.conv1(input_batch)
+        block_out = self.relu1(block_out)
+        block_out = self.conv2(block_out)
+        block_out = self.relu2(block_out)
 
+        return self.maxpool(block_out)
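The hard-coded 1152 in head_linear follows from the (32, 48, 48) chunks in dsets.py: each of the four LunaBlocks halves every spatial dimension while the channels grow 8 -> 16 -> 32 -> 64, leaving 64 * 2 * 3 * 3 = 1152 flattened features. A quick shape check, assuming the repo root is on sys.path:

    import torch
    from p2ch11.model import LunaModel

    model = LunaModel()
    input_t = torch.randn(2, 1, 32, 48, 48)        # batch of two fake chunks
    logits_t, probability_t = model(input_t)
    print(logits_t.shape, probability_t.shape)     # torch.Size([2, 2]) twice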

+ 1 - 1
p2ch11/prepcache.py

@@ -60,4 +60,4 @@ class LunaPrepCacheApp(object):
 
 
 if __name__ == '__main__':
-    sys.exit(LunaPrepCacheApp().main() or 0)
+    LunaPrepCacheApp().main()

+ 103 - 152
p2ch11/training.py

@@ -5,11 +5,11 @@ import sys
 
 import numpy as np
 
-from tensorboardX import SummaryWriter
+from torch.utils.tensorboard import SummaryWriter
 
 import torch
 import torch.nn as nn
-from torch.optim import SGD
+from torch.optim import SGD, Adam
 from torch.utils.data import DataLoader
 
 from util.util import enumerateWithEstimate
@@ -22,7 +22,7 @@ log = logging.getLogger(__name__)
 log.setLevel(logging.INFO)
 # log.setLevel(logging.DEBUG)
 
-# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
 METRICS_LABEL_NDX=0
 METRICS_PRED_NDX=1
 METRICS_LOSS_NDX=2
@@ -34,85 +34,38 @@ class LunaTrainingApp(object):
             sys_argv = sys.argv[1:]
 
         parser = argparse.ArgumentParser()
-        parser.add_argument('--batch-size',
-            help='Batch size to use for training',
-            default=32,
-            type=int,
-        )
         parser.add_argument('--num-workers',
             help='Number of worker processes for background data loading',
             default=8,
             type=int,
         )
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=32,
+            type=int,
+        )
         parser.add_argument('--epochs',
             help='Number of epochs to train for',
             default=1,
             type=int,
         )
-        parser.add_argument('--balanced',
-            help="Balance the training data to half benign, half malignant.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--augmented',
-            help="Augment the training data.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--augment-flip',
-            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--augment-offset',
-            help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--augment-scale',
-            help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--augment-rotate',
-            help="Augment the training data by randomly rotating the data around the head-foot axis.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--augment-noise',
-            help="Augment the training data by randomly adding noise to the data.",
-            action='store_true',
-            default=False,
-        )
 
         parser.add_argument('--tb-prefix',
             default='p2ch11',
             help="Data prefix to use for Tensorboard run. Defaults to chapter.",
         )
+
         parser.add_argument('comment',
             help="Comment suffix for Tensorboard run.",
             nargs='?',
-            default='none',
+            default='dwlpt',
         )
-
         self.cli_args = parser.parse_args(sys_argv)
         self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
 
-        self.totalTrainingSamples_count = 0
         self.trn_writer = None
-        self.tst_writer = None
-
-        self.augmentation_dict = {}
-        if self.cli_args.augmented or self.cli_args.augment_flip:
-            self.augmentation_dict['flip'] = True
-        if self.cli_args.augmented or self.cli_args.augment_offset:
-            self.augmentation_dict['offset'] = 0.1
-        if self.cli_args.augmented or self.cli_args.augment_scale:
-            self.augmentation_dict['scale'] = 0.2
-        if self.cli_args.augmented or self.cli_args.augment_rotate:
-            self.augmentation_dict['rotate'] = True
-        if self.cli_args.augmented or self.cli_args.augment_noise:
-            self.augmentation_dict['noise'] = 25.0
+        self.val_writer = None
+        self.totalTrainingSamples_count = 0
 
         self.use_cuda = torch.cuda.is_available()
         self.device = torch.device("cuda" if self.use_cuda else "cpu")
@@ -120,10 +73,10 @@ class LunaTrainingApp(object):
         self.model = self.initModel()
         self.optimizer = self.initOptimizer()
 
-
     def initModel(self):
         model = LunaModel()
         if self.use_cuda:
+            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
             if torch.cuda.device_count() > 1:
                 model = nn.DataParallel(model)
             model = model.to(self.device)
@@ -135,10 +88,8 @@ class LunaTrainingApp(object):
 
     def initTrainDl(self):
         train_ds = LunaDataset(
-            test_stride=10,
-            isTestSet_bool=False,
-            ratio_int=int(self.cli_args.balanced),
-            augmentation_dict=self.augmentation_dict,
+            val_stride=10,
+            isValSet_bool=False,
         )
 
         train_dl = DataLoader(
@@ -150,40 +101,36 @@ class LunaTrainingApp(object):
 
         return train_dl
 
-    def initTestDl(self):
-        test_ds = LunaDataset(
-            test_stride=10,
-            isTestSet_bool=True,
+    def initValDl(self):
+        val_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=True,
         )
 
-        test_dl = DataLoader(
-            test_ds,
+        val_dl = DataLoader(
+            val_ds,
             batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
             num_workers=self.cli_args.num_workers,
             pin_memory=self.use_cuda,
         )
 
-        return test_dl
+        return val_dl
 
     def initTensorboardWriters(self):
         if self.trn_writer is None:
             log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
 
-            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_cls_' + self.cli_args.comment)
-            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_cls_' + self.cli_args.comment)
-# eng::tb_writer[]
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
+            self.val_writer = SummaryWriter(log_dir=log_dir + '-val_cls-' + self.cli_args.comment)
 
 
     def main(self):
         log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
 
         train_dl = self.initTrainDl()
-        test_dl = self.initTestDl()
+        val_dl = self.initValDl()
 
         self.initTensorboardWriters()
-        # self.logModelMetrics(self.model)
-
-        # best_score = 0.0
 
         for epoch_ndx in range(1, self.cli_args.epochs + 1):
 
@@ -191,26 +138,25 @@ class LunaTrainingApp(object):
                 epoch_ndx,
                 self.cli_args.epochs,
                 len(train_dl),
-                len(test_dl),
+                len(val_dl),
                 self.cli_args.batch_size,
                 (torch.cuda.device_count() if self.use_cuda else 1),
             ))
 
-            trnMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
-            self.logMetrics(epoch_ndx, 'trn', trnMetrics_tensor)
+            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
 
-            tstMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
-            self.logMetrics(epoch_ndx, 'tst', tstMetrics_tensor)
+            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
+            self.logMetrics(epoch_ndx, 'val', valMetrics_t)
 
         if hasattr(self, 'trn_writer'):
             self.trn_writer.close()
-            self.tst_writer.close()
+            self.val_writer.close()
 
 
     def doTraining(self, epoch_ndx, train_dl):
         self.model.train()
-        train_dl.dataset.shuffleSamples()
-        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        trnMetrics_g = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
         batch_iter = enumerateWithEstimate(
             train_dl,
             "E{} Training".format(epoch_ndx),
@@ -223,104 +169,110 @@ class LunaTrainingApp(object):
                 batch_ndx,
                 batch_tup,
                 train_dl.batch_size,
-                trainingMetrics_devtensor
+                trnMetrics_g
             )
 
             loss_var.backward()
             self.optimizer.step()
             del loss_var
 
-        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
+            # # This is for adding the model graph to TensorBoard.
+            # if epoch_ndx == 1 and batch_ndx == 0:
+            #     with torch.no_grad():
+            #         model = LunaModel()
+            #         self.trn_writer.add_graph(model, batch_tup[0], verbose=True)
+            #         self.trn_writer.close()
+
+        self.totalTrainingSamples_count += len(train_dl.dataset)
 
-        return trainingMetrics_devtensor.to('cpu')
+        return trnMetrics_g.to('cpu')
 
 
-    def doTesting(self, epoch_ndx, test_dl):
+    def doValidation(self, epoch_ndx, val_dl):
         with torch.no_grad():
             self.model.eval()
-            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
+            valMetrics_g = torch.zeros(METRICS_SIZE, len(val_dl.dataset)).to(self.device)
             batch_iter = enumerateWithEstimate(
-                test_dl,
-                "E{} Testing ".format(epoch_ndx),
-                start_ndx=test_dl.num_workers,
+                val_dl,
+                "E{} Validation ".format(epoch_ndx),
+                start_ndx=val_dl.num_workers,
             )
             for batch_ndx, batch_tup in batch_iter:
-                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
+                self.computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)
 
-        return testingMetrics_devtensor.to('cpu')
+        return valMetrics_g.to('cpu')
 
 
 
-    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
-        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
+        input_t, label_t, _series_list, _center_list = batch_tup
 
-        input_devtensor = input_tensor.to(self.device, non_blocking=True)
-        label_devtensor = label_tensor.to(self.device, non_blocking=True)
+        input_g = input_t.to(self.device, non_blocking=True)
+        label_g = label_t.to(self.device, non_blocking=True)
 
-        logits_devtensor, probability_devtensor = self.model(input_devtensor)
+        logits_g, probability_g = self.model(input_g)
 
         loss_func = nn.CrossEntropyLoss(reduction='none')
-        loss_devtensor = loss_func(logits_devtensor, label_devtensor[:,1])
+        loss_g = loss_func(
+            logits_g,
+            label_g[:,1],
+        )
         start_ndx = batch_ndx * batch_size
-        end_ndx = start_ndx + label_tensor.size(0)
+        end_ndx = start_ndx + label_t.size(0)
 
-        metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_devtensor[:,1]
-        metrics_devtensor[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_devtensor[:,1]
-        metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor
+        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_g[:,1]
+        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_g[:,1]
+        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g
 
-        return loss_devtensor.mean()
+        return loss_g.mean()
 
 
     def logMetrics(
             self,
             epoch_ndx,
             mode_str,
-            metrics_tensor,
+            metrics_t,
     ):
+        self.initTensorboardWriters()
         log.info("E{} {}".format(
             epoch_ndx,
             type(self).__name__,
         ))
 
-        metrics_ary = metrics_tensor.cpu().detach().numpy()
-#         assert np.isfinite(metrics_ary).all()
-
-        benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= 0.5
-        benPred_mask = metrics_ary[METRICS_PRED_NDX] <= 0.5
+        benLabel_mask = metrics_t[METRICS_LABEL_NDX] <= 0.5
+        benPred_mask = metrics_t[METRICS_PRED_NDX] <= 0.5
 
         malLabel_mask = ~benLabel_mask
         malPred_mask = ~benPred_mask
 
-        benLabel_count = benLabel_mask.sum()
-        malLabel_count = malLabel_mask.sum()
+        ben_count = benLabel_mask.sum()
+        mal_count = malLabel_mask.sum()
 
-        trueNeg_count = benCorrect_count = (benLabel_mask & benPred_mask).sum()
-        truePos_count = malCorrect_count = (malLabel_mask & malPred_mask).sum()
+        ben_correct = (benLabel_mask & benPred_mask).sum()
+        mal_correct = (malLabel_mask & malPred_mask).sum()
 
-        falsePos_count = benLabel_count - benCorrect_count
-        falseNeg_count = malLabel_count - malCorrect_count
+        # trueNeg_count = ben_correct = (benLabel_mask & benPred_mask).sum()
+        # truePos_count = mal_correct = (malLabel_mask & malPred_mask).sum()
+        #
+        # falsePos_count = ben_count - ben_correct
+        # falseNeg_count = mal_count - mal_correct
 
-        metrics_dict = {}
-        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
-        metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
-        metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
+        # log.info(['min loss', metrics_a[METRICS_LOSS_NDX, benLabel_mask].min(), metrics_a[METRICS_LOSS_NDX, malLabel_mask].min()])
+        # log.info(['max loss', metrics_a[METRICS_LOSS_NDX, benLabel_mask].max(), metrics_a[METRICS_LOSS_NDX, malLabel_mask].max()])
 
-        metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
-        metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
-        metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
 
-        precision = metrics_dict['pr/precision'] = truePos_count / (truePos_count + falsePos_count)
-        recall    = metrics_dict['pr/recall']    = truePos_count / (truePos_count + falseNeg_count)
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_t[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/ben'] = metrics_t[METRICS_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = metrics_t[METRICS_LOSS_NDX, malLabel_mask].mean()
 
-        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / (precision + recall)
+        metrics_dict['correct/all'] = (mal_correct + ben_correct) / metrics_t.shape[1] * 100
+        metrics_dict['correct/ben'] = (ben_correct) / ben_count * 100
+        metrics_dict['correct/mal'] = (mal_correct) / mal_count * 100
 
         log.info(
-            ("E{} {:8} "
-                 + "{loss/all:.4f} loss, "
+            ("E{} {:8} {loss/all:.4f} loss, "
                  + "{correct/all:-5.1f}% correct, "
-                 + "{pr/precision:.4f} precision, "
-                 + "{pr/recall:.4f} recall, "
-                 + "{pr/f1_score:.4f} f1 score"
             ).format(
                 epoch_ndx,
                 mode_str,
@@ -328,29 +280,28 @@ class LunaTrainingApp(object):
             )
         )
         log.info(
-            ("E{} {:8} "
-                 + "{loss/ben:.4f} loss, "
-                 + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
+            ("E{} {:8} {loss/ben:.4f} loss, "
+                 + "{correct/ben:-5.1f}% correct ({ben_correct:} of {ben_count:})"
             ).format(
                 epoch_ndx,
                 mode_str + '_ben',
-                benCorrect_count=benCorrect_count,
-                benLabel_count=benLabel_count,
+                ben_correct=ben_correct,
+                ben_count=ben_count,
                 **metrics_dict,
             )
         )
         log.info(
-            ("E{} {:8} "
-                 + "{loss/mal:.4f} loss, "
-                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+            ("E{} {:8} {loss/mal:.4f} loss, "
+                 + "{correct/mal:-5.1f}% correct ({mal_correct:} of {mal_count:})"
             ).format(
                 epoch_ndx,
                 mode_str + '_mal',
-                malCorrect_count=malCorrect_count,
-                malLabel_count=malLabel_count,
+                mal_correct=mal_correct,
+                mal_count=mal_count,
                 **metrics_dict,
             )
         )
+
         writer = getattr(self, mode_str + '_writer')
 
         for key, value in metrics_dict.items():
@@ -358,27 +309,27 @@ class LunaTrainingApp(object):
 
         writer.add_pr_curve(
             'pr',
-            metrics_ary[METRICS_LABEL_NDX],
-            metrics_ary[METRICS_PRED_NDX],
+            metrics_t[METRICS_LABEL_NDX],
+            metrics_t[METRICS_PRED_NDX],
             self.totalTrainingSamples_count,
         )
 
         bins = [x/50.0 for x in range(51)]
 
-        benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
-        malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
+        benHist_mask = benLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
+        malHist_mask = malLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)
 
         if benHist_mask.any():
             writer.add_histogram(
                 'is_ben',
-                metrics_ary[METRICS_PRED_NDX, benHist_mask],
+                metrics_t[METRICS_PRED_NDX, benHist_mask],
                 self.totalTrainingSamples_count,
                 bins=bins,
             )
         if malHist_mask.any():
             writer.add_histogram(
                 'is_mal',
-                metrics_ary[METRICS_PRED_NDX, malHist_mask],
+                metrics_t[METRICS_PRED_NDX, malHist_mask],
                 self.totalTrainingSamples_count,
                 bins=bins,
             )
@@ -407,7 +358,7 @@ class LunaTrainingApp(object):
     #                 writer.add_histogram(
     #                     name.rsplit('.', 1)[-1] + '/' + name,
     #                     param.data.cpu().numpy(),
-    #                     # metrics_ary[METRICS_PRED_NDX, benHist_mask],
+    #                     # metrics_a[METRICS_PRED_NDX, benHist_mask],
     #                     self.totalTrainingSamples_count,
     #                     # bins=bins,
     #                 )
@@ -417,4 +368,4 @@ class LunaTrainingApp(object):
 
 
 if __name__ == '__main__':
-    sys.exit(LunaTrainingApp().main() or 0)
+    LunaTrainingApp().main()
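The pr/precision, pr/recall, and pr/f1_score entries are dropped in this revision; rebuilding them from the masked counts above is a one-liner each. A sketch with made-up counts in the same style:

    ben_count, mal_count = 90, 10          # hypothetical label counts
    ben_correct, mal_correct = 85, 6       # hypothetical correct predictions

    truePos = mal_correct
    falsePos = ben_count - ben_correct
    falseNeg = mal_count - mal_correct

    precision = truePos / (truePos + falsePos)           # 6 / 11 ~= 0.545
    recall = truePos / (truePos + falseNeg)              # 6 / 10  = 0.6
    f1 = 2 * precision * recall / (precision + recall)   # ~= 0.571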

+ 45 - 33
p2ch11/vis.py

@@ -6,14 +6,14 @@ import matplotlib.pyplot as plt
 
 from p2ch11.dsets import Ct, LunaDataset
 
-clim=(-1000.0, 1300)
+clim=(-1000.0, 300)
 
-def findMalignantSamples(start_ndx=0, limit=10):
+def findMalignantSamples(start_ndx=0, limit=100):
     ds = LunaDataset(sortby_str='malignancy_size')
 
     malignantSample_list = []
     for sample_tup in ds.noduleInfo_list:
-        if sample_tup[0]:
+        if sample_tup.isMalignant_bool:
             print(len(malignantSample_list), sample_tup)
             malignantSample_list.append(sample_tup)
 
@@ -23,8 +23,8 @@ def findMalignantSamples(start_ndx=0, limit=10):
     return malignantSample_list
 
 def showNodule(series_uid, batch_ndx=None, **kwargs):
-    ds = LunaDataset(series_uid=series_uid, sortby_str='malignancy_size', **kwargs)
-    malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x[0]]
+    ds = LunaDataset(series_uid=series_uid, **kwargs)
+    malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x.isMalignant_bool]
 
     if batch_ndx is None:
         if malignant_list:
@@ -34,54 +34,66 @@ def showNodule(series_uid, batch_ndx=None, **kwargs):
             batch_ndx = 0
 
     ct = Ct(series_uid)
-    # ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
-    # malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor = ds[batch_ndx]
-    nodule_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
-    ct_ary = nodule_tensor[0].numpy()
+    ct_t, malignant_t, series_uid, center_irc = ds[batch_ndx]
+    ct_a = ct_t[0].numpy()
 
-
-    fig = plt.figure(figsize=(15, 25))
+    fig = plt.figure(figsize=(30, 50))
 
     group_list = [
-        #[0,1,2],
-        [3,4,5],
-        [6,7,8],
-        [9,10,11],
-        #[12,13,14],
-        #[15]
+        [9, 11, 13],
+        [15, 16, 17],
+        [19, 21, 23],
     ]
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
-    subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
+    subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[int(center_irc.index)], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
-    subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
+    subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[:,int(center_irc.row)], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
-    subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+    subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
-    subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct_ary[ct_ary.shape[0]//2], clim=clim, cmap='gray')
+    subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[ct_a.shape[0]//2], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
-    subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct_ary[:,ct_ary.shape[1]//2], clim=clim, cmap='gray')
+    subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[:,ct_a.shape[1]//2], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
-    subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct_ary[:,:,ct_ary.shape[2]//2], clim=clim, cmap='gray')
+    subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[:,:,ct_a.shape[2]//2], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     for row, index_list in enumerate(group_list):
         for col, index in enumerate(index_list):
             subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
-            subplot.set_title('slice {}'.format(index))
-            plt.imshow(ct_ary[index*2], clim=clim, cmap='gray')
+            subplot.set_title('slice {}'.format(index), fontsize=30)
+            for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+                label.set_fontsize(20)
+            plt.imshow(ct_a[index], clim=clim, cmap='gray')
+
 
+    print(series_uid, batch_ndx, bool(malignant_t[0]), malignant_list)
 
-    print(series_uid, batch_ndx, bool(malignant_tensor[1]), malignant_list, ct.vxSize_xyz)
 
-    return ct_ary
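Both vis.py revisions add invert_yaxis() to the row and col cross-sections, presumably so the slice index increases upward in those views. The call in isolation:

    import numpy as np
    import matplotlib.pyplot as plt

    gradient_a = np.arange(64 * 64).reshape(64, 64)
    plt.imshow(gradient_a, cmap='gray')    # row 0 renders at the top...
    plt.gca().invert_yaxis()               # ...and at the bottom after inverting
    plt.show()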

File diff suppressed because it is too large
+ 59 - 0
p2ch12/1_final_metric_f1_score.ipynb


+ 38 - 44
p2ch12/diagnose.py

@@ -52,7 +52,7 @@ class LunaDiagnoseApp(object):
         )
 
         parser.add_argument('--include-train',
-            help="Include data that was in the training set. (default: test data only)",
+            help="Include data that was in the training set. (default: validation data only)",
             action='store_true',
             default=False,
         )
@@ -177,13 +177,13 @@ class LunaDiagnoseApp(object):
     def main(self):
         log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
 
-        test_ds = LunaDataset(
-            test_stride=10,
-            isTestSet_bool=True,
+        val_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=True,
         )
-        test_set = set(
+        val_set = set(
             noduleInfo_tup.series_uid
-            for noduleInfo_tup in test_ds.noduleInfo_list
+            for noduleInfo_tup in val_ds.noduleInfo_list
         )
         malignant_set = set(
             noduleInfo_tup.series_uid
@@ -199,22 +199,22 @@ class LunaDiagnoseApp(object):
                 for noduleInfo_tup in getNoduleInfoList()
             )
 
-        train_list = sorted(series_set - test_set) if self.cli_args.include_train else []
-        test_list = sorted(series_set & test_set)
+        train_list = sorted(series_set - val_set) if self.cli_args.include_train else []
+        val_list = sorted(series_set & val_set)
 
 
         noduleInfo_list = []
         series_iter = enumerateWithEstimate(
-            test_list + train_list,
+            val_list + train_list,
             "Series",
         )
         for _series_ndx, series_uid in series_iter:
-            ct, output_ary, _mask_ary, clean_ary = self.segmentCt(series_uid)
+            ct, output_a, _mask_a, clean_a = self.segmentCt(series_uid)
 
             noduleInfo_list += self.clusterSegmentationOutput(
                 series_uid,
                 ct,
-                clean_ary,
+                clean_a,
             )
 
             # if _series_ndx > 10:
@@ -230,21 +230,21 @@ class LunaDiagnoseApp(object):
             start_ndx=cls_dl.num_workers,
         )
         for batch_ndx, batch_tup in batch_iter:
-            input_tensor, _, series_list, center_list = batch_tup
+            input_t, _, series_list, center_list = batch_tup
 
-            input_devtensor = input_tensor.to(self.device)
+            input_g = input_t.to(self.device)
             with torch.no_grad():
-                _logits_devtensor, probability_devtensor = self.cls_model(input_devtensor)
+                _logits_g, probability_g = self.cls_model(input_g)
 
             classifications_list = zip(
                 series_list,
                 center_list,
-                probability_devtensor[:,1].to('cpu'),
+                probability_g[:,1].to('cpu'),
             )
 
             for cls_tup in classifications_list:
-                series_uid, center_irc, probability_tensor = cls_tup
-                probability_float = probability_tensor.item()
+                series_uid, center_irc, probability_t = cls_tup
+                probability_float = probability_t.item()

                 this_tup = (probability_float, tuple(center_irc))
                 current_tup = series2diagnosis_dict.get(series_uid, this_tup)
@@ -257,60 +257,54 @@ class LunaDiagnoseApp(object):
                     log.debug([(type(x), x) for x in this_tup] + [(type(x), x) for x in current_tup])
                     raise
 
-                # self.logResults(
-                #     'Testing' if isTest_bool else 'Training',
-                #     [(series_uid, series2diagnosis_dict[series_uid])],
-                #     malignant_set,
-                # )
-
         log.info('Training set:')
         self.logResults('Training', train_list, series2diagnosis_dict, malignant_set)
 
-        log.info('Testing set:')
-        self.logResults('Testing', test_list, series2diagnosis_dict, malignant_set)
+        log.info('Validation set:')
+        self.logResults('Validation', val_list, series2diagnosis_dict, malignant_set)
 
     def segmentCt(self, series_uid):
         with torch.no_grad():
             ct = getCt(series_uid)
 
-            output_ary = np.zeros_like(ct.ary, dtype=np.float32)
+            output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
 
             seg_dl = self.initSegmentationDl(series_uid)
             for batch_tup in seg_dl:
-                input_tensor = batch_tup[0]
+                input_t = batch_tup[0]
                 ndx_list = batch_tup[6]
 
-                input_devtensor = input_tensor.to(self.device)
-                prediction_devtensor = self.seg_model(input_devtensor)
+                input_g = input_t.to(self.device)
+                prediction_g = self.seg_model(input_g)
 
                 for i, sample_ndx in enumerate(ndx_list):
-                    output_ary[sample_ndx] = prediction_devtensor[i].cpu().numpy()
+                    output_a[sample_ndx] = prediction_g[i].cpu().numpy()
 
-            mask_ary = output_ary > 0.5
-            clean_ary = morph.binary_erosion(mask_ary, iterations=1)
-            clean_ary = morph.binary_dilation(clean_ary, iterations=2)
+            mask_a = output_a > 0.5
+            clean_a = morph.binary_erosion(mask_a, iterations=1)
+            clean_a = morph.binary_dilation(clean_a, iterations=2)
 
-        return ct, output_ary, mask_ary, clean_ary
+        return ct, output_a, mask_a, clean_a
 
-    def clusterSegmentationOutput(self, series_uid,  ct, clean_ary):
-        noduleLabel_ary, nodule_count = measure.label(clean_ary)
+    def clusterSegmentationOutput(self, series_uid,  ct, clean_a):
+        noduleLabel_a, nodule_count = measure.label(clean_a)
         centerIrc_list = measure.center_of_mass(
-            ct.ary + 1001,
-            labels=noduleLabel_ary,
+            ct.hu_a + 1001,
+            labels=noduleLabel_a,
             index=list(range(1, nodule_count+1)),
         )
 
         # n = 1298
         # log.debug([
-        #     (noduleLabel_ary == n).sum(),
-        #     np.where(noduleLabel_ary == n),
+        #     (noduleLabel_a == n).sum(),
+        #     np.where(noduleLabel_a == n),
         #
-        #     ct.ary[noduleLabel_ary == n].sum(),
-        #     (ct.ary + 1000)[noduleLabel_ary == n].sum(),
+        #     ct.hu_a[noduleLabel_a == n].sum(),
+        #     (ct.hu_a + 1000)[noduleLabel_a == n].sum(),
         # ])
 
-        if nodule_count < 2:
-            centerIrc_list = [centerIrc_list]
+        # if nodule_count == 1:
+        #     centerIrc_list = [centerIrc_list]
 
         noduleInfo_list = []
         for i, center_irc in enumerate(centerIrc_list):
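segmentCt and clusterSegmentationOutput above chain binary erosion/dilation into connected-component labeling. A toy end-to-end sketch with scipy.ndimage, using two fake cubes in place of segmentation output and the mask itself as the centroid weight:

    import numpy as np
    import scipy.ndimage as ndimage

    mask_a = np.zeros((16, 16, 16), dtype=bool)
    mask_a[2:5, 2:5, 2:5] = True                  # first fake candidate
    mask_a[10:13, 10:13, 10:13] = True            # second fake candidate

    clean_a = ndimage.binary_erosion(mask_a, iterations=1)    # drop specks
    clean_a = ndimage.binary_dilation(clean_a, iterations=2)  # grow back

    label_a, blob_count = ndimage.label(clean_a)
    centerIrc_list = ndimage.center_of_mass(
        clean_a.astype(np.float32),
        labels=label_a,
        index=list(range(1, blob_count + 1)),
    )
    print(blob_count, centerIrc_list)             # two blobs, two (i, r, c) centers

Because index is passed as a list, center_of_mass always returns a list, even for a single blob; that is why the single-nodule special case above could be commented out.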

+ 58 - 311
p2ch12/dsets.py

@@ -9,15 +9,12 @@ import random
 from collections import namedtuple
 
 import SimpleITK as sitk
-
 import numpy as np
-import scipy.ndimage.morphology as morph
 
 import torch
 import torch.cuda
-from torch.utils.data import Dataset
-import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.data import Dataset
 
 from util.disk import getCache
 from util.util import XyzTuple, xyz2irc
@@ -31,7 +28,6 @@ log.setLevel(logging.DEBUG)
 raw_cache = getCache('part2ch12_raw')
 
 NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
-MaskTuple = namedtuple('MaskTuple', 'air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, ben_mask, mal_mask')
 
 @functools.lru_cache(1)
 def getNoduleInfoList(requireDataOnDisk_bool=True):
@@ -77,168 +73,57 @@ def getNoduleInfoList(requireDataOnDisk_bool=True):
     return noduleInfo_list
 
 class Ct(object):
-    def __init__(self, series_uid, buildMasks_bool=True):
+    def __init__(self, series_uid):
         mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
 
         ct_mhd = sitk.ReadImage(mhd_path)
-        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
 
         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
         # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
         # This gets rid of negative density stuff used to indicate out-of-FOV
-        ct_ary[ct_ary < -1000] = -1000
+        ct_a[ct_a < -1000] = -1000
 
         # This nukes any weird hotspots and clamps bone down
-        ct_ary[ct_ary > 1000] = 1000
+        ct_a[ct_a > 1000] = 1000
 
         self.series_uid = series_uid
-        self.ary = ct_ary
+        self.hu_a = ct_a
 
         self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
         self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
         self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
 
-        noduleInfo_list = getNoduleInfoList()
-        self.benignInfo_list = [ni_tup
-                                for ni_tup in noduleInfo_list
-                                    if not ni_tup.isMalignant_bool
-                                        and ni_tup.series_uid == self.series_uid]
-        self.benign_mask = self.buildAnnotationMask(self.benignInfo_list)[0]
-        self.benign_indexes = sorted(set(self.benign_mask.nonzero()[0]))
-
-        self.malignantInfo_list = [ni_tup
-                                   for ni_tup in noduleInfo_list
-                                        if ni_tup.isMalignant_bool
-                                            and ni_tup.series_uid == self.series_uid]
-        self.malignant_mask = self.buildAnnotationMask(self.malignantInfo_list)[0]
-        self.malignant_indexes = sorted(set(self.malignant_mask.nonzero()[0]))
-
-    def buildAnnotationMask(self, noduleInfo_list, threshold_gcc = -500):
-        boundingBox_ary = np.zeros_like(self.ary, dtype=np.bool)
-
-        for noduleInfo_tup in noduleInfo_list:
-            center_irc = xyz2irc(
-                noduleInfo_tup.center_xyz,
-                self.origin_xyz,
-                self.vxSize_xyz,
-                self.direction_tup,
-            )
-            ci = int(center_irc.index)
-            cr = int(center_irc.row)
-            cc = int(center_irc.col)
-
-            index_radius = 2
-            try:
-                while self.ary[ci + index_radius, cr, cc] > threshold_gcc and \
-                        self.ary[ci - index_radius, cr, cc] > threshold_gcc:
-                    index_radius += 1
-            except IndexError:
-                index_radius -= 1
-
-            row_radius = 2
-            try:
-                while self.ary[ci, cr + row_radius, cc] > threshold_gcc and \
-                        self.ary[ci, cr - row_radius, cc] > threshold_gcc:
-                    row_radius += 1
-            except IndexError:
-                row_radius -= 1
-
-            col_radius = 2
-            try:
-                while self.ary[ci, cr, cc + col_radius] > threshold_gcc and \
-                        self.ary[ci, cr, cc - col_radius] > threshold_gcc:
-                    col_radius += 1
-            except IndexError:
-                col_radius -= 1
-
-            # assert index_radius > 0, repr([noduleInfo_tup.center_xyz, center_irc, self.ary[ci, cr, cc]])
-            # assert row_radius > 0
-            # assert col_radius > 0
-
-
-            slice_tup = (
-                slice(ci - index_radius, ci + index_radius + 1),
-                slice(cr - row_radius, cr + row_radius + 1),
-                slice(cc - col_radius, cc + col_radius + 1),
-            )
-            boundingBox_ary[slice_tup] = True
-
-        thresholded_ary = boundingBox_ary & (self.ary > threshold_gcc)
-        mask_ary = morph.binary_dilation(thresholded_ary, iterations=2)
-
-        return mask_ary, thresholded_ary, boundingBox_ary
-
-    def build2dLungMask(self, mask_ndx, threshold_gcc = -300):
-        dense_mask = self.ary[mask_ndx] > threshold_gcc
-        denoise_mask = morph.binary_closing(dense_mask, iterations=2)
-        tissue_mask = morph.binary_opening(denoise_mask, iterations=10)
-        body_mask = morph.binary_fill_holes(tissue_mask)
-        air_mask = morph.binary_fill_holes(body_mask & ~tissue_mask)
-
-        lung_mask = morph.binary_dilation(air_mask, iterations=2)
-
-        ben_mask = denoise_mask & air_mask
-        ben_mask = morph.binary_dilation(ben_mask, iterations=2)
-        ben_mask &= ~self.malignant_mask[mask_ndx]
-
-        mal_mask = self.malignant_mask[mask_ndx]
-
-        return MaskTuple(
-            air_mask,
-            lung_mask,
-            dense_mask,
-            denoise_mask,
-            tissue_mask,
-            body_mask,
-            ben_mask,
-            mal_mask,
-        )
-
-    def build3dLungMask(self):
-        air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, ben_mask, mal_mask = mask_list = \
-            [np.zeros_like(self.ary, dtype=np.bool) for _ in range(8)]
-
-        for mask_ndx in range(self.ary.shape[0]):
-            for i, mask_ary in enumerate(self.build2dLungMask(mask_ndx)):
-                mask_list[i][mask_ndx] = mask_ary
-
-        return MaskTuple(air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, ben_mask, mal_mask)
-
-
     def getRawNodule(self, center_xyz, width_irc):
         center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
 
         slice_list = []
         for axis, center_val in enumerate(center_irc):
-            try:
-                start_ndx = int(round(center_val - width_irc[axis]/2))
-            except:
-                log.debug([center_val, width_irc, center_xyz, center_irc])
-                raise
+            start_ndx = int(round(center_val - width_irc[axis]/2))
             end_ndx = int(start_ndx + width_irc[axis])
 
-            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
 
             if start_ndx < 0:
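                 # the crop sticks out past the low edge of the volume; shift it inward, keeping its size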
                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                 start_ndx = 0
                 end_ndx = int(width_irc[axis])
 
-            if end_ndx > self.ary.shape[axis]:
+            if end_ndx > self.hu_a.shape[axis]:
                 # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                end_ndx = self.ary.shape[axis]
-                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
+                end_ndx = self.hu_a.shape[axis]
+                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
 
             slice_list.append(slice(start_ndx, end_ndx))
 
-        ct_chunk = self.ary[tuple(slice_list)]
+        ct_chunk = self.hu_a[tuple(slice_list)]
 
         return ct_chunk, center_irc
 
-ctCache_depth = 5
-@functools.lru_cache(ctCache_depth, typed=True)
+
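+# a fully loaded Ct is large, so keep only the single most recently used one in memory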
+@functools.lru_cache(1, typed=True)
 def getCt(series_uid):
     return Ct(series_uid)
 
@@ -248,11 +133,6 @@ def getCtRawNodule(series_uid, center_xyz, width_irc):
     ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
     return ct_chunk, center_irc
 
-@raw_cache.memoize(typed=True)
-def getCtSampleSize(series_uid):
-    ct = Ct(series_uid, buildMasks_bool=False)
-    return len(ct.benign_indexes)
-
 def getCtAugmentedNodule(
         augmentation_dict,
         series_uid, center_xyz, width_irc,
@@ -263,63 +143,65 @@ def getCtAugmentedNodule(
         ct = getCt(series_uid)
         ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
 
-    ct_tensor = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
+    ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
 
-    transform_tensor = torch.eye(4).to(torch.float64)
+    transform_t = torch.eye(4).to(torch.float64)
+    # ... <1>
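+    # transform_t is a 4x4 homogeneous-coordinate affine; the augmentations
+    # below perturb it before it is handed to F.affine_grid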
 
     for i in range(3):
         if 'flip' in augmentation_dict:
             if random.random() > 0.5:
-                transform_tensor[i,i] *= -1
+                transform_t[i,i] *= -1
 
         if 'offset' in augmentation_dict:
             offset_float = augmentation_dict['offset']
             random_float = (random.random() * 2 - 1)
-            transform_tensor[3,i] = offset_float * random_float
+            # translation belongs in column 3; row 3 would be dropped by the transform_t[:3] slice below
+            transform_t[i,3] = offset_float * random_float
 
         if 'scale' in augmentation_dict:
             scale_float = augmentation_dict['scale']
             random_float = (random.random() * 2 - 1)
-            transform_tensor[i,i] *= 1.0 + scale_float * random_float
+            transform_t[i,i] *= 1.0 + scale_float * random_float
+
 
     if 'rotate' in augmentation_dict:
         angle_rad = random.random() * math.pi * 2
         s = math.sin(angle_rad)
         c = math.cos(angle_rad)
 
-        rotation_tensor = torch.tensor([
+        rotation_t = torch.tensor([
             [c, -s, 0, 0],
             [s, c, 0, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 1],
         ], dtype=torch.float64)
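         # this matrix rotates the first two axes and leaves the third fixed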
 
-        transform_tensor @= rotation_tensor
+        transform_t @= rotation_t
 
-    affine_tensor = F.affine_grid(
-            transform_tensor[:3].unsqueeze(0).to(torch.float32),
-            ct_tensor.size(),
+    affine_t = F.affine_grid(
+            transform_t[:3].unsqueeze(0).to(torch.float32),
+            ct_t.size(),
         )
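     # affine_grid expands the 3x4 matrix into a dense sampling grid, and grid_sample
     # resamples the chunk through it, applying flip/offset/scale/rotate in one pass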
 
     augmented_chunk = F.grid_sample(
-            ct_tensor,
-            affine_tensor,
+            ct_t,
+            affine_t,
             padding_mode='border'
         ).to('cpu')
 
     if 'noise' in augmentation_dict:
-        noise_tensor = torch.randn_like(augmented_chunk)
-        noise_tensor *= augmentation_dict['noise']
+        noise_t = torch.randn_like(augmented_chunk)
+        noise_t *= augmentation_dict['noise']
 
-        augmented_chunk += noise_tensor
+        augmented_chunk += noise_t
 
     return augmented_chunk[0], center_irc
 
 
 class LunaDataset(Dataset):
     def __init__(self,
-                 test_stride=0,
-                 isTestSet_bool=None,
+                 val_stride=0,
+                 isValSet_bool=None,
                  series_uid=None,
                  sortby_str='random',
                  ratio_int=0,
@@ -337,25 +219,20 @@ class LunaDataset(Dataset):
             self.use_cache = True
 
         if series_uid:
-            self.series_list = [series_uid]
-        else:
-            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+            self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid == series_uid]
 
-        if isTestSet_bool:
-            assert test_stride > 0, test_stride
-            self.series_list = self.series_list[::test_stride]
-            assert self.series_list
-        elif test_stride > 0:
-            del self.series_list[::test_stride]
-            assert self.series_list
-
-        series_set = set(self.series_list)
-        self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid in series_set]
+        if isValSet_bool:
+            assert val_stride > 0, val_stride
+            self.noduleInfo_list = self.noduleInfo_list[::val_stride]
+            assert self.noduleInfo_list
+        elif val_stride > 0:
+            del self.noduleInfo_list[::val_stride]
+            assert self.noduleInfo_list
 
         if sortby_str == 'random':
             random.shuffle(self.noduleInfo_list)
         elif sortby_str == 'series_uid':
-            self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sorting by series_uid, center_xyz)
+            self.noduleInfo_list.sort(key=lambda x: (x.series_uid, x.center_xyz))
         elif sortby_str == 'malignancy_size':
             pass
         else:
@@ -367,7 +244,7 @@ class LunaDataset(Dataset):
         log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
             self,
             len(self.noduleInfo_list),
-            "testing" if isTestSet_bool else "training",
+            "validation" if isValSet_bool else "training",
             len(self.benign_list),
             len(self.malignant_list),
             '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
@@ -380,10 +257,10 @@ class LunaDataset(Dataset):
 
     def __len__(self):
         if self.ratio_int:
-            # return 20000
-            return 200000
+            return 20000
+            # return 200000
         else:
-            return len(self.noduleInfo_list)
+            return len(self.noduleInfo_list) // 20  # note: uses only a twentieth of the list
 
     def __getitem__(self, ndx):
         if self.ratio_int:
@@ -391,13 +268,15 @@ class LunaDataset(Dataset):
 
             if ndx % (self.ratio_int + 1):
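                 # nonzero remainders (ratio_int of every ratio_int+1 samples) draw
                 # benign; the zero remainder below draws malignant, oversampling it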
                 benign_ndx = ndx - 1 - malignant_ndx
-                nodule_tup = self.benign_list[benign_ndx % len(self.benign_list)]
+                benign_ndx %= len(self.benign_list)
+                nodule_tup = self.benign_list[benign_ndx]
             else:
-                nodule_tup = self.malignant_list[malignant_ndx % len(self.malignant_list)]
+                malignant_ndx %= len(self.malignant_list)
+                nodule_tup = self.malignant_list[malignant_ndx]
         else:
             nodule_tup = self.noduleInfo_list[ndx]
 
-        width_irc = (24, 48, 48)
+        width_irc = (32, 48, 48)
 
         if self.augmentation_dict:
             nodule_t, center_irc = getCtAugmentedNodule(
@@ -408,162 +287,30 @@ class LunaDataset(Dataset):
                 self.use_cache,
             )
         elif self.use_cache:
-            nodule_ary, center_irc = getCtRawNodule(
+            nodule_a, center_irc = getCtRawNodule(
                 nodule_tup.series_uid,
                 nodule_tup.center_xyz,
                 width_irc,
             )
-            nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
+            nodule_t = torch.from_numpy(nodule_a).to(torch.float32)
             nodule_t = nodule_t.unsqueeze(0)
         else:
             ct = getCt(nodule_tup.series_uid)
-            nodule_ary, center_irc = ct.getRawNodule(
+            nodule_a, center_irc = ct.getRawNodule(
                 nodule_tup.center_xyz,
                 width_irc,
             )
-            nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
+            nodule_t = torch.from_numpy(nodule_a).to(torch.float32)
             nodule_t = nodule_t.unsqueeze(0)
 
-        malignant_tensor = torch.tensor([
+        malignant_t = torch.tensor([
                 not nodule_tup.isMalignant_bool,
                 nodule_tup.isMalignant_bool
             ],
             dtype=torch.long,
         )
 
-        # log.debug([type(center_irc), center_irc])
-
-        return nodule_t, malignant_tensor, nodule_tup.series_uid, torch.tensor(center_irc)
-
-
-
-
-class Luna2dSegmentationDataset(Dataset):
-    def __init__(self,
-                 test_stride=0,
-                 isTestSet_bool=None,
-                 series_uid=None,
-                 contextSlices_count=2,
-                 augmentation_dict=None,
-                 fullCt_bool=False,
-            ):
-        self.contextSlices_count = contextSlices_count
-        self.augmentation_dict = augmentation_dict
-
-        if series_uid:
-            self.series_list = [series_uid]
-        else:
-            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
-
-        if isTestSet_bool:
-            assert test_stride > 0, test_stride
-            self.series_list = self.series_list[::test_stride]
-            assert self.series_list
-        elif test_stride > 0:
-            del self.series_list[::test_stride]
-            assert self.series_list
-
-        self.sample_list = []
-        for series_uid in self.series_list:
-            if fullCt_bool:
-                self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCt(series_uid).ary.shape[0])])
-            else:
-                self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCtSampleSize(series_uid))])
-
-        log.info("{!r}: {} {} series, {} slices".format(
-            self,
-            len(self.series_list),
-            {None: 'general', True: 'testing', False: 'training'}[isTestSet_bool],
-            len(self.sample_list),
-        ))
-
-    def __len__(self):
-        return len(self.sample_list) #// 100
-
-    def __getitem__(self, ndx):
-        if isinstance(ndx, int):
-            series_uid, sample_ndx = self.sample_list[ndx % len(self.sample_list)]
-            ct = getCt(series_uid)
-            ct_ndx = sample_ndx
-            useAugmentation_bool = False
-        else:
-            series_uid, ct_ndx, useAugmentation_bool = ndx
-            ct = getCt(series_uid)
-
-        ct_tensor = torch.zeros((self.contextSlices_count * 2 + 1 + 1, 512, 512))
-
-        start_ndx = ct_ndx - self.contextSlices_count
-        end_ndx = ct_ndx + self.contextSlices_count + 1
-        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
-            context_ndx = max(context_ndx, 0)
-            context_ndx = min(context_ndx, ct.ary.shape[0] - 1)
-
-            ct_tensor[i] = torch.from_numpy(ct.ary[context_ndx].astype(np.float32))
-        ct_tensor /= 1000
-
-        mask_tup = ct.build2dLungMask(ct_ndx)
+        return nodule_t, malignant_t, nodule_tup.series_uid, torch.tensor(center_irc)
 
-        ct_tensor[-1] = torch.from_numpy(mask_tup.body_mask.astype(np.float32))
 
-        nodule_tensor = torch.from_numpy(
-            (mask_tup.mal_mask | mask_tup.ben_mask).astype(np.float32)
-        ).unsqueeze(0)
-        ben_tensor = torch.from_numpy(mask_tup.ben_mask.astype(np.float32))
-        mal_tensor = torch.from_numpy(mask_tup.mal_mask.astype(np.float32))
-        label_int = mal_tensor.max() + ben_tensor.max() * 2
-
-        if self.augmentation_dict and useAugmentation_bool:
-            if 'rotate' in self.augmentation_dict:
-                if random.random() > 0.5:
-                    ct_tensor = ct_tensor.rot90(1, [1, 2])
-                    nodule_tensor = nodule_tensor.rot90(1, [1, 2])
-
-            if 'flip' in self.augmentation_dict:
-                dims = [d+1 for d in range(2) if random.random() > 0.5]
-
-                if dims:
-                    ct_tensor = ct_tensor.flip(dims)
-                    nodule_tensor = nodule_tensor.flip(dims)
-
-            if 'noise' in self.augmentation_dict:
-                noise_tensor = torch.randn_like(ct_tensor)
-                noise_tensor *= self.augmentation_dict['noise']
-
-                ct_tensor += noise_tensor
-        return ct_tensor, nodule_tensor, label_int, ben_tensor, mal_tensor, ct.series_uid, ct_ndx
-
-
-class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
-    def __init__(self, *args, batch_size=80, **kwargs):
-        self.needsShuffle_bool = True
-        self.batch_size = batch_size
-        # self.rotate_frac = 0.5 * len(self.series_list) / len(self)
-        super().__init__(*args, **kwargs)
-
-    def __len__(self):
-        return 50000
-
-    def __getitem__(self, ndx):
-        if self.needsShuffle_bool:
-            random.shuffle(self.series_list)
-            self.needsShuffle_bool = False
-
-        if isinstance(ndx, int):
-            if ndx % self.batch_size == 0:
-                self.series_list.append(self.series_list.pop(0))
-
-            series_uid = self.series_list[ndx % ctCache_depth]
-            ct = getCt(series_uid)
-
-            if ndx % 3 == 0:
-                ct_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
-            elif ndx % 3 == 1:
-                ct_ndx = random.choice(ct.benign_indexes)
-            elif ndx % 3 == 2:
-                ct_ndx = random.choice(list(range(ct.ary.shape[0])))
-
-            useAugmentation_bool = True
-        else:
-            series_uid, ct_ndx, useAugmentation_bool = ndx
 
-        return super().__getitem__((series_uid, ct_ndx, useAugmentation_bool))

+ 53 - 14
p2ch12/model.py

@@ -3,39 +3,78 @@ import math
 from torch import nn as nn
 
 from util.logconf import logging
-from util.unet import UNet
 
 log = logging.getLogger(__name__)
 # log.setLevel(logging.WARN)
 # log.setLevel(logging.INFO)
 log.setLevel(logging.DEBUG)
 
-class UNetWrapper(nn.Module):
-    def __init__(self, **kwargs):
+
+class LunaModel(nn.Module):
+    def __init__(self, in_channels=1, conv_channels=8):
         super().__init__()
 
-        self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
-        self.unet = UNet(**kwargs)
-        self.final = nn.Sigmoid()
+        self.tail_batchnorm = nn.BatchNorm3d(1)
+
+        self.block1 = LunaBlock(in_channels, conv_channels)
+        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
+        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
+        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)
+
+        self.head_linear = nn.Linear(1152, 2)
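+        # 1152 = (conv_channels * 8) * 2*3*3 with the defaults, assuming the
+        # (32, 48, 48) crops from dsets.py halved four times by the blocks' max-pooling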
+        self.head_softmax = nn.Softmax(dim=1)
 
+        self._init_weights()
+
+    # see also https://github.com/pytorch/pytorch/issues/18182
+    def _init_weights(self):
         for m in self.modules():
             if type(m) in {
-                nn.Conv2d,
+                nn.Linear,
                 nn.Conv3d,
+                nn.Conv2d,
                 nn.ConvTranspose2d,
                 nn.ConvTranspose3d,
-                nn.Linear,
             }:
-                nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='leaky_relu', a=0)
+                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
                 if m.bias is not None:
                     fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                     bound = 1 / math.sqrt(fan_out)
                     nn.init.normal_(m.bias, -bound, bound)
 
 
-    def forward(self, input):
-        bn_output = self.batchnorm(input)
-        un_output = self.unet(bn_output)
-        fn_output = self.final(un_output)
+    def forward(self, input_batch):
+        bn_output = self.tail_batchnorm(input_batch)
+
+        block_out = self.block1(bn_output)
+        block_out = self.block2(block_out)
+        block_out = self.block3(block_out)
+        block_out = self.block4(block_out)
+
+        conv_flat = block_out.view(
+            block_out.size(0),
+            -1,
+        )
+        linear_output = self.head_linear(conv_flat)
+
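+        # return raw logits (for nn.CrossEntropyLoss) along with softmax probabilities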
+        return linear_output, self.head_softmax(linear_output)
+
+
+class LunaBlock(nn.Module):
+    def __init__(self, in_channels, conv_channels):
+        super().__init__()
+
+        self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.maxpool = nn.MaxPool3d(2, 2)
+
+    def forward(self, input_batch):
+        block_out = self.conv1(input_batch)
+        block_out = self.relu1(block_out)
+        block_out = self.conv2(block_out)
+        block_out = self.relu2(block_out)
 
-        return fn_output
+        return self.maxpool(block_out)

+ 5 - 14
p2ch12/prepcache.py

@@ -9,9 +9,9 @@ from torch.optim import SGD
 from torch.utils.data import DataLoader
 
 from util.util import enumerateWithEstimate
-from .dsets import LunaDataset, getCtSampleSize
+from .dsets import LunaDataset
 from util.logconf import logging
-# from .model import LunaModel
+from .model import LunaModel
 
 log = logging.getLogger(__name__)
 # log.setLevel(logging.WARN)
@@ -36,11 +36,6 @@ class LunaPrepCacheApp(object):
             default=8,
             type=int,
         )
-        # parser.add_argument('--scaled',
-        #     help="Scale the CT chunks to square voxels.",
-        #     default=False,
-        #     action='store_true',
-        # )
 
         self.cli_args = parser.parse_args(sys_argv)
 
@@ -60,13 +55,9 @@ class LunaPrepCacheApp(object):
             "Stuffing cache",
             start_ndx=self.prep_dl.num_workers,
         )
-        for batch_ndx, batch_tup in batch_iter:
-            _nodule_tensor, _malignant_tensor, series_list, _center_list = batch_tup
-            for series_uid in sorted(set(series_list)):
-                getCtSampleSize(series_uid)
-            # input_tensor, label_tensor, _series_list, _start_list = batch_tup
-
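+        # just pulling every batch through the DataLoader fills the on-disk cache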
+        for _ in batch_iter:
+            pass
 
 
 if __name__ == '__main__':
-    sys.exit(LunaPrepCacheApp().main() or 0)
+    LunaPrepCacheApp().main()

+ 1 - 1
p2ch12/screencts.py

@@ -29,7 +29,7 @@ class LunaScreenCtDataset(Dataset):
     def __getitem__(self, ndx):
         series_uid = self.series_list[ndx]
         ct = getCt(series_uid)
-        mid_ndx = ct.ary.shape[0] // 2
+        mid_ndx = ct.hu_a.shape[0] // 2
 
         air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, ben_mask, mal_mask = ct.build2dLungMask(mid_ndx)
 

+ 77 - 69
p2ch12/train_cls.py

@@ -23,7 +23,7 @@ log = logging.getLogger(__name__)
 # log.setLevel(logging.INFO)
 log.setLevel(logging.DEBUG)
 
-# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
 METRICS_LABEL_NDX=0
 METRICS_PRED_NDX=1
 METRICS_LOSS_NDX=2
@@ -101,7 +101,7 @@ class LunaTrainingApp(object):
 
         self.totalTrainingSamples_count = 0
         self.trn_writer = None
-        self.tst_writer = None
+        self.val_writer = None
 
         self.augmentation_dict = {}
         if self.cli_args.augmented or self.cli_args.augment_flip:
@@ -136,8 +136,8 @@ class LunaTrainingApp(object):
 
     def initTrainDl(self):
         train_ds = LunaDataset(
-            test_stride=10,
-            isTestSet_bool=False,
+            val_stride=10,
+            isValSet_bool=False,
             ratio_int=int(self.cli_args.balanced),
             augmentation_dict=self.augmentation_dict,
         )
@@ -151,35 +151,34 @@ class LunaTrainingApp(object):
 
         return train_dl
 
-    def initTestDl(self):
-        test_ds = LunaDataset(
-            test_stride=10,
-            isTestSet_bool=True,
+    def initValDl(self):
+        val_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=True,
         )
 
-        test_dl = DataLoader(
-            test_ds,
+        val_dl = DataLoader(
+            val_ds,
             batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
             num_workers=self.cli_args.num_workers,
             pin_memory=self.use_cuda,
         )
 
-        return test_dl
+        return val_dl
 
     def initTensorboardWriters(self):
         if self.trn_writer is None:
             log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
 
             self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_cls_' + self.cli_args.comment)
-            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_cls_' + self.cli_args.comment)
-# eng::tb_writer[]
+            self.val_writer = SummaryWriter(log_dir=log_dir + '_val_cls_' + self.cli_args.comment)
 
 
     def main(self):
         log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
 
         train_dl = self.initTrainDl()
-        test_dl = self.initTestDl()
+        val_dl = self.initValDl()
 
         best_score = 0.0
 
@@ -189,29 +188,32 @@ class LunaTrainingApp(object):
                 epoch_ndx,
                 self.cli_args.epochs,
                 len(train_dl),
-                len(test_dl),
+                len(val_dl),
                 self.cli_args.batch_size,
                 (torch.cuda.device_count() if self.use_cuda else 1),
             ))
 
-            trnMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
-            self.logMetrics(epoch_ndx, 'trn', trnMetrics_tensor)
+            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
 
-            tstMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
-            score = self.logMetrics(epoch_ndx, 'tst', tstMetrics_tensor)
+            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
+            score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
             best_score = max(score, best_score)
 
             self.saveModel('cls', epoch_ndx, score == best_score)
 
         if hasattr(self, 'trn_writer'):
             self.trn_writer.close()
-            self.tst_writer.close()
+            self.val_writer.close()
 
 
     def doTraining(self, epoch_ndx, train_dl):
         self.model.train()
         train_dl.dataset.shuffleSamples()
-        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        trnMetrics_g = torch.zeros(
+                METRICS_SIZE,
+                len(train_dl.dataset),
+            ).to(self.device)
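+        # one column per sample in the epoch, rows indexed by the METRICS_* constants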
         batch_iter = enumerateWithEstimate(
             train_dl,
             "E{} Training".format(epoch_ndx),
@@ -224,67 +226,68 @@ class LunaTrainingApp(object):
                 batch_ndx,
                 batch_tup,
                 train_dl.batch_size,
-                trainingMetrics_devtensor
+                trnMetrics_g
             )
 
             loss_var.backward()
             self.optimizer.step()
             del loss_var
 
-        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
+        self.totalTrainingSamples_count += trnMetrics_g.size(1)
 
-        return trainingMetrics_devtensor.to('cpu')
+        return trnMetrics_g.to('cpu')
 
 
-    def doTesting(self, epoch_ndx, test_dl):
+    def doValidation(self, epoch_ndx, val_dl):
         with torch.no_grad():
             self.model.eval()
-            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
+            valMetrics_g = torch.zeros(
+                    METRICS_SIZE,
+                    len(val_dl.dataset),
+                ).to(self.device)
             batch_iter = enumerateWithEstimate(
-                test_dl,
-                "E{} Testing ".format(epoch_ndx),
-                start_ndx=test_dl.num_workers,
+                val_dl,
+                "E{} Validation ".format(epoch_ndx),
+                start_ndx=val_dl.num_workers,
             )
             for batch_ndx, batch_tup in batch_iter:
-                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
-
-        return testingMetrics_devtensor.to('cpu')
+                self.computeBatchLoss(
+                    batch_ndx,
+                    batch_tup,
+                    val_dl.batch_size,
+                    valMetrics_g,
+                )
 
+        return valMetrics_g.to('cpu')
 
 
-    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
-        input_tensor, label_tensor, _series_list, _center_list = batch_tup
 
-        input_devtensor = input_tensor.to(self.device, non_blocking=True)
-        label_devtensor = label_tensor.to(self.device, non_blocking=True)
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
+        input_t, label_t, _series_list, _center_list = batch_tup
 
-        logits_devtensor, probability_devtensor = self.model(input_devtensor)
+        input_g = input_t.to(self.device, non_blocking=True)
+        label_g = label_t.to(self.device, non_blocking=True)
 
-        # log.debug(['input', input_devtensor.min().item(), input_devtensor.max().item()])
-        # log.debug(['label', label_devtensor.min().item(), label_devtensor.max().item()])
-        # log.debug(['logits', logits_devtensor.min().item(), logits_devtensor.max().item()])
-        # log.debug(['probability', probability_devtensor.min().item(), probability_devtensor.max().item()])
+        logits_g, probability_g = self.model(input_g)
 
         loss_func = nn.CrossEntropyLoss(reduction='none')
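         # reduction='none' yields one loss per sample so each can be recorded in
         # metrics_g below; the mean is taken only when returning for backprop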
-        loss_devtensor = loss_func(logits_devtensor, label_devtensor[:,1])
-
-        # log.debug(['loss', loss_devtensor.min().item(), loss_devtensor.max().item()])
+        loss_g = loss_func(logits_g, label_g[:,1])
 
         start_ndx = batch_ndx * batch_size
-        end_ndx = start_ndx + label_tensor.size(0)
+        end_ndx = start_ndx + label_t.size(0)
 
-        metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_devtensor[:,1]
-        metrics_devtensor[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_devtensor[:,1]
-        metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor
+        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_g[:,1]
+        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_g[:,1]
+        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g
 
-        return loss_devtensor.mean()
+        return loss_g.mean()
 
 
     def logMetrics(
             self,
             epoch_ndx,
             mode_str,
-            metrics_tensor,
+            metrics_t,
     ):
         self.initTensorboardWriters()
         log.info("E{} {}".format(
@@ -292,11 +295,11 @@ class LunaTrainingApp(object):
             type(self).__name__,
         ))
 
-        metrics_ary = metrics_tensor.cpu().detach().numpy()
-#         assert np.isfinite(metrics_ary).all()
+        metrics_a = metrics_t.cpu().detach().numpy()
+#         assert np.isfinite(metrics_a).all()
 
-        benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= 0.5
-        benPred_mask = metrics_ary[METRICS_PRED_NDX] <= 0.5
+        benLabel_mask = metrics_a[METRICS_LABEL_NDX] <= 0.5
+        benPred_mask = metrics_a[METRICS_PRED_NDX] <= 0.5
 
         malLabel_mask = ~benLabel_mask
         malPred_mask = ~benPred_mask
@@ -311,18 +314,21 @@ class LunaTrainingApp(object):
         falseNeg_count = malLabel_count - malCorrect_count
 
         metrics_dict = {}
-        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
-        metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
-        metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
+        metrics_dict['loss/all'] = metrics_a[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/ben'] = metrics_a[METRICS_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = metrics_a[METRICS_LOSS_NDX, malLabel_mask].mean()
 
-        metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
+        metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_a.shape[1] * 100
         metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
         metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
 
-        precision = metrics_dict['pr/precision'] = truePos_count / (truePos_count + falsePos_count)
-        recall    = metrics_dict['pr/recall']    = truePos_count / (truePos_count + falseNeg_count)
+        precision = metrics_dict['pr/precision'] = \
+            truePos_count / (truePos_count + falsePos_count)
+        recall = metrics_dict['pr/recall'] = \
+            truePos_count / (truePos_count + falseNeg_count)
 
-        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / (precision + recall)
+        metrics_dict['pr/f1_score'] = 2 * (precision * recall) \
+            / (precision + recall)
 
         log.info(
             ("E{} {:8} "
@@ -340,7 +346,8 @@ class LunaTrainingApp(object):
         log.info(
             ("E{} {:8} "
                  + "{loss/ben:.4f} loss, "
-                 + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
+                 + "{correct/ben:-5.1f}% correct "
+                 + "({benCorrect_count:} of {benLabel_count:})"
             ).format(
                 epoch_ndx,
                 mode_str + '_ben',
@@ -352,7 +359,8 @@ class LunaTrainingApp(object):
         log.info(
             ("E{} {:8} "
                  + "{loss/mal:.4f} loss, "
-                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+                 + "{correct/mal:-5.1f}% correct "
+                 + "({malCorrect_count:} of {malLabel_count:})"
             ).format(
                 epoch_ndx,
                 mode_str + '_mal',
@@ -368,27 +376,27 @@ class LunaTrainingApp(object):
 
         writer.add_pr_curve(
             'pr',
-            metrics_ary[METRICS_LABEL_NDX],
-            metrics_ary[METRICS_PRED_NDX],
+            metrics_a[METRICS_LABEL_NDX],
+            metrics_a[METRICS_PRED_NDX],
             self.totalTrainingSamples_count,
         )
 
         bins = [x/50.0 for x in range(51)]
 
-        benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
-        malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
+        benHist_mask = benLabel_mask & (metrics_a[METRICS_PRED_NDX] > 0.01)
+        malHist_mask = malLabel_mask & (metrics_a[METRICS_PRED_NDX] < 0.99)
 
         if benHist_mask.any():
             writer.add_histogram(
                 'is_ben',
-                metrics_ary[METRICS_PRED_NDX, benHist_mask],
+                metrics_a[METRICS_PRED_NDX, benHist_mask],
                 self.totalTrainingSamples_count,
                 bins=bins,
             )
         if malHist_mask.any():
             writer.add_histogram(
                 'is_mal',
-                metrics_ary[METRICS_PRED_NDX, malHist_mask],
+                metrics_a[METRICS_PRED_NDX, malHist_mask],
                 self.totalTrainingSamples_count,
                 bins=bins,
             )

+ 186 - 144
p2ch12/train_seg.py

@@ -25,7 +25,7 @@ log = logging.getLogger(__name__)
 # log.setLevel(logging.INFO)
 log.setLevel(logging.DEBUG)
 
-# Used for computeClassificationLoss and logMetrics to index into metrics_tensor/metrics_ary
+# Used for computeClassificationLoss and logMetrics to index into metrics_t/metrics_a
 METRICS_LABEL_NDX = 0
 METRICS_LOSS_NDX = 1
 METRICS_MAL_LOSS_NDX = 2
@@ -108,7 +108,7 @@ class LunaTrainingApp(object):
         self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
         self.totalTrainingSamples_count = 0
         self.trn_writer = None
-        self.tst_writer = None
+        self.val_writer = None
 
         augmentation_dict = {}
         if self.cli_args.augmented or self.cli_args.augment_flip:
@@ -116,7 +116,7 @@ class LunaTrainingApp(object):
         if self.cli_args.augmented or self.cli_args.augment_rotate:
             augmentation_dict['rotate'] = True
         if self.cli_args.augmented or self.cli_args.augment_noise:
-            augmentation_dict['noise'] = 25.0
+            augmentation_dict['noise'] = 0.025
         self.augmentation_dict = augmentation_dict
 
         self.use_cuda = torch.cuda.is_available()
@@ -127,7 +127,15 @@ class LunaTrainingApp(object):
 
 
     def initModel(self):
-        model = UNetWrapper(in_channels=8, n_classes=1, depth=4, wf=3, padding=True, batch_norm=True, up_mode='upconv')
+        model = UNetWrapper(
+            in_channels=8,
+            n_classes=1,
+            depth=4,
+            wf=3,
+            padding=True,
+            batch_norm=True,
+            up_mode='upconv',
+        )
 
         if self.use_cuda:
             if torch.cuda.device_count() > 1:
@@ -142,8 +150,8 @@ class LunaTrainingApp(object):
 
     def initTrainDl(self):
         train_ds = TrainingLuna2dSegmentationDataset(
-            test_stride=10,
-            isTestSet_bool=False,
+            val_stride=10,
+            isValSet_bool=False,
             contextSlices_count=3,
             augmentation_dict=self.augmentation_dict,
         )
@@ -157,34 +165,34 @@ class LunaTrainingApp(object):
 
         return train_dl
 
-    def initTestDl(self):
-        test_ds = Luna2dSegmentationDataset(
-            test_stride=10,
-            isTestSet_bool=True,
+    def initValDl(self):
+        val_ds = Luna2dSegmentationDataset(
+            val_stride=10,
+            isValSet_bool=True,
             contextSlices_count=3,
         )
 
-        test_dl = DataLoader(
-            test_ds,
+        val_dl = DataLoader(
+            val_ds,
             batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
             num_workers=self.cli_args.num_workers,
             pin_memory=self.use_cuda,
         )
 
-        return test_dl
+        return val_dl
 
     def initTensorboardWriters(self):
         if self.trn_writer is None:
             log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
 
             self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
-            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_seg_' + self.cli_args.comment)
+            self.val_writer = SummaryWriter(log_dir=log_dir + '_val_seg_' + self.cli_args.comment)
 
     def main(self):
         log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
 
         train_dl = self.initTrainDl()
-        test_dl = self.initTestDl()
+        val_dl = self.initValDl()
 
         # self.logModelMetrics(self.model)
 
@@ -194,29 +202,29 @@ class LunaTrainingApp(object):
                 epoch_ndx,
                 self.cli_args.epochs,
                 len(train_dl),
-                len(test_dl),
+                len(val_dl),
                 self.cli_args.batch_size,
                 (torch.cuda.device_count() if self.use_cuda else 1),
             ))
 
-            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
-            self.logMetrics(epoch_ndx, 'trn', trainingMetrics_tensor)
+            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
             self.logImages(epoch_ndx, 'trn', train_dl)
-            self.logImages(epoch_ndx, 'tst', test_dl)
+            self.logImages(epoch_ndx, 'val', val_dl)
             # self.logModelMetrics(self.model)
 
-            testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
-            score = self.logMetrics(epoch_ndx, 'tst', testingMetrics_tensor)
+            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
+            score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
             best_score = max(score, best_score)
 
             self.saveModel('seg', epoch_ndx, score == best_score)
 
         if hasattr(self, 'trn_writer'):
             self.trn_writer.close()
-            self.tst_writer.close()
+            self.val_writer.close()
 
     def doTraining(self, epoch_ndx, train_dl):
-        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        trnMetrics_g = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
         self.model.train()
 
         batch_iter = enumerateWithEstimate(
@@ -227,193 +235,223 @@ class LunaTrainingApp(object):
         for batch_ndx, batch_tup in batch_iter:
             self.optimizer.zero_grad()
 
-            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_devtensor)
+            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trnMetrics_g)
             loss_var.backward()
 
             self.optimizer.step()
             del loss_var
 
-        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
+        self.totalTrainingSamples_count += trnMetrics_g.size(1)
 
-        return trainingMetrics_devtensor.to('cpu')
+        return trnMetrics_g.to('cpu')
 
-    def doTesting(self, epoch_ndx, test_dl):
+    def doValidation(self, epoch_ndx, val_dl):
         with torch.no_grad():
-            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
+            valMetrics_g = torch.zeros(METRICS_SIZE, len(val_dl.dataset)).to(self.device)
             self.model.eval()
 
             batch_iter = enumerateWithEstimate(
-                test_dl,
-                "E{} Testing ".format(epoch_ndx),
-                start_ndx=test_dl.num_workers,
+                val_dl,
+                "E{} Validation ".format(epoch_ndx),
+                start_ndx=val_dl.num_workers,
             )
             for batch_ndx, batch_tup in batch_iter:
-                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
+                self.computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)
 
-        return testingMetrics_devtensor.to('cpu')
+        return valMetrics_g.to('cpu')
 
-    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
-        input_tensor, label_tensor, label_list, ben_tensor, mal_tensor, _series_list, _start_list = batch_tup
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
+        input_t, label_t, label_list, ben_t, mal_t, _, _ = batch_tup
 
-        input_devtensor = input_tensor.to(self.device, non_blocking=True)
-        label_devtensor = label_tensor.to(self.device, non_blocking=True)
-        mal_devtensor = mal_tensor.to(self.device, non_blocking=True)
-        ben_devtensor = ben_tensor.to(self.device, non_blocking=True)
+        input_g = input_t.to(self.device, non_blocking=True)
+        label_g = label_t.to(self.device, non_blocking=True)
+        mal_g = mal_t.to(self.device, non_blocking=True)
+        ben_g = ben_t.to(self.device, non_blocking=True)
 
         start_ndx = batch_ndx * batch_size
-        end_ndx = start_ndx + label_tensor.size(0)
+        end_ndx = start_ndx + label_t.size(0)
         intersectionSum = lambda a, b: (a * b).view(a.size(0), -1).sum(dim=1)
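         # for binary masks, the per-sample sum of the elementwise product is the intersection size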
 
-        prediction_devtensor = self.model(input_devtensor)
-        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
+        prediction_g = self.model(input_g)
+        diceLoss_g = self.diceLoss(label_g, prediction_g)
 
         with torch.no_grad():
-            predictionBool_devtensor = (prediction_devtensor > 0.5).to(torch.float32)
+            malLoss_g = self.diceLoss(mal_g, prediction_g * mal_g, p=True)
+            predictionBool_g = (prediction_g > 0.5).to(torch.float32)
 
-            metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
-            metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_devtensor
+            metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
+            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
+            metrics_g[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_g
 
-            # benPred_devtensor = predictionBool_devtensor * (1 - mal_devtensor)
-            tp = intersectionSum(    label_devtensor,     predictionBool_devtensor)
-            fn = intersectionSum(    label_devtensor, 1 - predictionBool_devtensor)
-            fp = intersectionSum(1 - label_devtensor,     predictionBool_devtensor)
-            # ls = self.diceLoss(label_devtensor, benPred_devtensor)
+            malPred_g = predictionBool_g * mal_g
+            tp = intersectionSum(    mal_g,       malPred_g)
+            fn = intersectionSum(    mal_g,   1 - malPred_g)
 
-            metrics_devtensor[METRICS_ATP_NDX, start_ndx:end_ndx] = tp
-            metrics_devtensor[METRICS_AFN_NDX, start_ndx:end_ndx] = fn
-            metrics_devtensor[METRICS_AFP_NDX, start_ndx:end_ndx] = fp
-            # metrics_devtensor[METRICS_ALL_LOSS_NDX, start_ndx:end_ndx] = ls
+            metrics_g[METRICS_MTP_NDX, start_ndx:end_ndx] = tp
+            metrics_g[METRICS_MFN_NDX, start_ndx:end_ndx] = fn
 
-            del tp, fn, fp
+            del malPred_g, tp, fn
 
-            malPred_devtensor = predictionBool_devtensor * (1 - ben_devtensor)
-            tp = intersectionSum(    mal_devtensor,       malPred_devtensor)
-            fn = intersectionSum(    mal_devtensor,   1 - malPred_devtensor)
-            fp = intersectionSum(1 - label_devtensor,     malPred_devtensor)
-            ls = self.diceLoss(mal_devtensor, malPred_devtensor)
+            tp = intersectionSum(    label_g,     predictionBool_g)
+            fn = intersectionSum(    label_g, 1 - predictionBool_g)
+            fp = intersectionSum(1 - label_g,     predictionBool_g)
 
-            metrics_devtensor[METRICS_MTP_NDX, start_ndx:end_ndx] = tp
-            metrics_devtensor[METRICS_MFN_NDX, start_ndx:end_ndx] = fn
-            # metrics_devtensor[METRICS_MFP_NDX, start_ndx:end_ndx] = fp
-            metrics_devtensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = ls
+            metrics_g[METRICS_ATP_NDX, start_ndx:end_ndx] = tp
+            metrics_g[METRICS_AFN_NDX, start_ndx:end_ndx] = fn
+            metrics_g[METRICS_AFP_NDX, start_ndx:end_ndx] = fp
 
-            del malPred_devtensor, tp, fn, fp, ls
+            del tp, fn, fp
 
-        return diceLoss_devtensor.mean()
+        return diceLoss_g.mean()
 
-    # def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01, p=False):
-    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=1024, p=False):
+    # def diceLoss(self, label_g, prediction_g, epsilon=0.01, p=False):
+    def diceLoss(self, label_g, prediction_g, epsilon=1, p=False):
         sum_dim1 = lambda t: t.view(t.size(0), -1).sum(dim=1)
 
-        diceLabel_devtensor = sum_dim1(label_devtensor)
-        dicePrediction_devtensor = sum_dim1(prediction_devtensor)
-        diceCorrect_devtensor = sum_dim1(prediction_devtensor * label_devtensor)
+        diceLabel_g = sum_dim1(label_g)
+        dicePrediction_g = sum_dim1(prediction_g)
+        diceCorrect_g = sum_dim1(prediction_g * label_g)
+
+        epsilon_g = torch.ones_like(diceCorrect_g) * epsilon
+        diceLoss_g = 1 - (2 * diceCorrect_g + epsilon_g) \
+            / (dicePrediction_g + diceLabel_g + epsilon_g)
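+        # soft Dice: 1 - 2|A∩B| / (|A| + |B|); epsilon keeps the ratio at 1
+        # (loss 0) when both the label and the prediction are empty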
 
-        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
-        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)
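+        # a soft Dice loss can never legitimately be negative, so a negative mean means a bug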
+        if p and diceLoss_g.mean() < 0:
+            correct_tmp = prediction_g * label_g
 
-        if p:
             log.debug([])
-            log.debug(['diceCorrect_devtensor   ', diceCorrect_devtensor[0].item()])
-            log.debug(['dicePrediction_devtensor', dicePrediction_devtensor[0].item()])
-            log.debug(['diceLabel_devtensor     ', diceLabel_devtensor[0].item()])
-            log.debug(['2*diceCorrect_devtensor ', 2 * diceCorrect_devtensor[0].item()])
-            log.debug(['Prediction + Label      ', dicePrediction_devtensor[0].item() + diceLabel_devtensor[0].item()])
-            log.debug(['diceLoss_devtensor      ', diceLoss_devtensor[0].item()])
+            log.debug(['diceCorrect_g   ', diceCorrect_g[0].item(), correct_tmp[0].min().item(), correct_tmp[0].mean().item(), correct_tmp[0].max().item(), correct_tmp.shape])
+            log.debug(['dicePrediction_g', dicePrediction_g[0].item(), prediction_g[0].min().item(), prediction_g[0].mean().item(), prediction_g[0].max().item(), prediction_g.shape])
+            log.debug(['diceLabel_g     ', diceLabel_g[0].item(), label_g[0].min().item(), label_g[0].mean().item(), label_g[0].max().item(), label_g.shape])
+            log.debug(['2*diceCorrect_g ', 2 * diceCorrect_g[0].item()])
+            log.debug(['Prediction + Label      ', dicePrediction_g[0].item() + diceLabel_g[0].item()])
+            log.debug(['diceLoss_g      ', diceLoss_g[0].item()])
+            assert False
 
-        return diceLoss_devtensor
+        return diceLoss_g
 
 
     def logImages(self, epoch_ndx, mode_str, dl):
-        for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
+        images_iter = sorted(dl.dataset.series_list)[:12]
+        for series_ndx, series_uid in enumerate(images_iter):
             ct = getCt(series_uid)
 
-            for slice_ndx in range(0, ct.ary.shape[0], ct.ary.shape[0] // 5):
-                sample_tup = dl.dataset[(series_uid, slice_ndx, False)]
+            for slice_ndx in range(6):
+                ct_ndx = slice_ndx * ct.hu_a.shape[0] // 5
+                ct_ndx = min(ct_ndx, ct.hu_a.shape[0] - 1)
+                sample_tup = dl.dataset[(series_uid, ct_ndx, False)]
 
-                ct_tensor, nodule_tensor, label_int, ben_tensor, mal_tensor, series_uid, ct_ndx = sample_tup
+                ct_t, nodule_t, _, ben_t, mal_t, _, _ = sample_tup
 
-                ct_tensor[:-1,:,:] += 1000
-                ct_tensor[:-1,:,:] /= 2000
+                ct_t[:-1,:,:] += 1
+                ct_t[:-1,:,:] /= 2
 
-                input_devtensor = ct_tensor.to(self.device)
-                label_devtensor = nodule_tensor.to(self.device)
+                input_g = ct_t.to(self.device)
+                label_g = nodule_t.to(self.device)
 
-                prediction_devtensor = self.model(input_devtensor.unsqueeze(0))[0]
-                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()
-                label_ary = nodule_tensor.numpy()
-                ben_ary = ben_tensor.numpy()
-                mal_ary = mal_tensor.numpy()
+                prediction_g = self.model(input_g.unsqueeze(0))[0]
+                prediction_a = prediction_g.to('cpu').detach().numpy()
+                label_a = nodule_t.numpy()
+                ben_a = ben_t.numpy()
+                mal_a = mal_t.numpy()
 
-                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
-                image_ary[:,:,:] = (ct_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1)))
-                image_ary[:,:,0] += prediction_ary[0] * (1 - label_ary[0])  # Red
-                image_ary[:,:,1] += prediction_ary[0] * mal_ary  # Green
-                image_ary[:,:,2] += prediction_ary[0] * ben_ary  # Blue
+                ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()
+
+                image_a = np.zeros((512, 512, 3), dtype=np.float32)
+                image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
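+                # red: predicted voxels outside the label (false positives);
+                # green: predictions on malignant voxels; blue: on benign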
+                image_a[:,:,0] += prediction_a[0] * (1 - label_a[0])
+                image_a[:,:,1] += prediction_a[0] * mal_a[0]
+                image_a[:,:,2] += prediction_a[0] * ben_a[0]
+                image_a *= 0.5
+                image_a[image_a < 0] = 0
+                image_a[image_a > 1] = 1
 
                 writer = getattr(self, mode_str + '_writer')
-                image_ary *= 0.5
-                image_ary[image_ary < 0] = 0
-                image_ary[image_ary > 1] = 1
-                writer.add_image('{}/{}_prediction_{}'.format(mode_str, i, slice_ndx), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+                writer.add_image(
+                    '{}/{}_prediction_{}'.format(
+                        mode_str,
+                        series_ndx,
+                        slice_ndx,
+                    ),
+                    image_a,
+                    self.totalTrainingSamples_count,
+                    dataformats='HWC',
+                )
 
-                # self.diceLoss(label_devtensor, prediction_devtensor, p=True)
+                # self.diceLoss(label_g, prediction_g, p=True)
 
                 if epoch_ndx == 1:
-                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
-                    image_ary[:,:,:] = (ct_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1)))
-                    image_ary[:,:,0] += (1 - label_ary[0]) * ct_tensor[-1].numpy() # Red
-                    image_ary[:,:,1] += mal_ary  # Green
-                    image_ary[:,:,2] += ben_ary  # Blue
-
-                    writer = getattr(self, mode_str + '_writer')
-                    image_ary *= 0.5
-                    image_ary[image_ary < 0] = 0
-                    image_ary[image_ary > 1] = 1
-                    writer.add_image('{}/{}_label_{}'.format(mode_str, i, slice_ndx), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+                    image_a = np.zeros((512, 512, 3), dtype=np.float32)
+                    image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
+                    image_a[:,:,0] += (1 - label_a[0]) * ct_t[-1].numpy() # Red
+                    image_a[:,:,1] += mal_a[0]  # Green
+                    image_a[:,:,2] += ben_a[0]  # Blue
+
+                    image_a *= 0.5
+                    image_a[image_a < 0] = 0
+                    image_a[image_a > 1] = 1
+                    writer.add_image(
+                        '{}/{}_label_{}'.format(
+                            mode_str,
+                            series_ndx,
+                            slice_ndx,
+                        ),
+                        image_a,
+                        self.totalTrainingSamples_count,
+                        dataformats='HWC',
+                    )
 
 
     def logMetrics(self,
         epoch_ndx,
         mode_str,
-        metrics_tensor,
+        metrics_t,
     ):
-        self.initTensorboardWriters()
         log.info("E{} {}".format(
             epoch_ndx,
             type(self).__name__,
         ))
 
-        metrics_ary = metrics_tensor.cpu().detach().numpy()
-        sum_ary = metrics_ary.sum(axis=1)
-        assert np.isfinite(metrics_ary).all()
+        metrics_a = metrics_t.cpu().detach().numpy()
+        sum_a = metrics_a.sum(axis=1)
+        assert np.isfinite(metrics_a).all()
 
-        malLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 1) | (metrics_ary[METRICS_LABEL_NDX] == 3)
+        malLabel_mask = (metrics_a[METRICS_LABEL_NDX] == 1) | (metrics_a[METRICS_LABEL_NDX] == 3)
 
-        # allLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 2) | (metrics_ary[METRICS_LABEL_NDX] == 3)
+        # allLabel_mask = (metrics_a[METRICS_LABEL_NDX] == 2) | (metrics_a[METRICS_LABEL_NDX] == 3)
 
-        allLabel_count = sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFN_NDX]
-        malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]
+        allLabel_count = sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]
+        malLabel_count = sum_a[METRICS_MTP_NDX] + sum_a[METRICS_MFN_NDX]
 
-        allCorrect_count = sum_ary[METRICS_ATP_NDX]
-        malCorrect_count = sum_ary[METRICS_MTP_NDX]
+        # allCorrect_count = sum_a[METRICS_ATP_NDX]
+        # malCorrect_count = sum_a[METRICS_MTP_NDX]
 #
 #             falsePos_count = allLabel_count - allCorrect_count
 #             falseNeg_count = malLabel_count - malCorrect_count
 
 
         metrics_dict = {}
-        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
-        metrics_dict['loss/mal'] = np.nan_to_num(metrics_ary[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
-        # metrics_dict['loss/all'] = metrics_ary[METRICS_ALL_LOSS_NDX, allLabel_mask].mean()
+        metrics_dict['loss/all'] = metrics_a[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/mal'] = np.nan_to_num(metrics_a[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
+        # metrics_dict['loss/all'] = metrics_a[METRICS_ALL_LOSS_NDX, allLabel_mask].mean()
+
+        # metrics_dict['correct/mal'] = sum_a[METRICS_MTP_NDX] / (sum_a[METRICS_MTP_NDX] + sum_a[METRICS_MFN_NDX]) * 100
+        # metrics_dict['correct/all'] = sum_a[METRICS_ATP_NDX] / (sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]) * 100
+
+        metrics_dict['percent_all/tp'] = sum_a[METRICS_ATP_NDX] / (allLabel_count or 1) * 100
+        metrics_dict['percent_all/fn'] = sum_a[METRICS_AFN_NDX] / (allLabel_count or 1) * 100
+        metrics_dict['percent_all/fp'] = sum_a[METRICS_AFP_NDX] / (allLabel_count or 1) * 100
 
-        metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
-        metrics_dict['correct/all'] = sum_ary[METRICS_ATP_NDX] / (sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFN_NDX]) * 100
+        metrics_dict['percent_mal/tp'] = sum_a[METRICS_MTP_NDX] / (malLabel_count or 1) * 100
+        metrics_dict['percent_mal/fn'] = sum_a[METRICS_MFN_NDX] / (malLabel_count or 1) * 100
 
-        precision = metrics_dict['pr/precision'] = sum_ary[METRICS_ATP_NDX] / ((sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFP_NDX]) or 1)
-        recall    = metrics_dict['pr/recall']    = sum_ary[METRICS_ATP_NDX] / ((sum_ary[METRICS_ATP_NDX] + sum_ary[METRICS_AFN_NDX]) or 1)
+        precision = metrics_dict['pr/precision'] = sum_a[METRICS_ATP_NDX] \
+            / ((sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFP_NDX]) or 1)
+        recall    = metrics_dict['pr/recall']    = sum_a[METRICS_ATP_NDX] \
+            / ((sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]) or 1)
 
-        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)
+        metrics_dict['pr/f1_score'] = 2 * (precision * recall) \
+            / ((precision + recall) or 1)
 
         log.info(("E{} {:8} "
                  + "{loss/all:.4f} loss, "
@@ -426,26 +464,30 @@ class LunaTrainingApp(object):
             **metrics_dict,
         ))
         log.info(("E{} {:8} "
-                 + "{loss/all:.4f} loss, "
-                 + "{correct/all:-5.1f}% correct ({allCorrect_count:} of {allLabel_count:})"
+                  + "{loss/all:.4f} loss, "
+                  + "{percent_all/tp:-5.1f}% tp, {percent_all/fn:-5.1f}% fn, {percent_all/fp:-9.1f}% fp"
+                 # + "{correct/all:-5.1f}% correct ({allCorrect_count:} of {allLabel_count:})"
         ).format(
             epoch_ndx,
             mode_str + '_all',
-            allCorrect_count=allCorrect_count,
-            allLabel_count=allLabel_count,
+            # allCorrect_count=allCorrect_count,
+            # allLabel_count=allLabel_count,
             **metrics_dict,
         ))
 
         log.info(("E{} {:8} "
-                 + "{loss/mal:.4f} loss, "
-                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+                  + "{loss/mal:.4f} loss, "
+                  + "{percent_mal/tp:-5.1f}% tp, {percent_mal/fn:-5.1f}% fn"
+                 # + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
         ).format(
             epoch_ndx,
             mode_str + '_mal',
-            malCorrect_count=malCorrect_count,
-            malLabel_count=malLabel_count,
+            # malCorrect_count=malCorrect_count,
+            # malLabel_count=malLabel_count,
             **metrics_dict,
         ))
+
+        self.initTensorboardWriters()
         writer = getattr(self, mode_str + '_writer')
 
         prefix_str = 'seg_'
@@ -454,9 +496,9 @@ class LunaTrainingApp(object):
             writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)
 
         score = 1 \
+            - metrics_dict['loss/mal'] \
             + metrics_dict['pr/f1_score'] \
             - metrics_dict['pr/recall'] * 0.01 \
-            - metrics_dict['loss/mal']  * 0.001 \
             - metrics_dict['loss/all']  * 0.0001
 
         return score
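
Note the `(denominator or 1)` idiom in the precision, recall, and F1 lines above: when the denominator sums to zero, `or` substitutes 1, so the metric comes out 0.0 instead of raising ZeroDivisionError. For example:

tp, fp = 0.0, 0.0
precision = tp / ((tp + fp) or 1)   # 0.0 rather than ZeroDivisionError

tp, fp = 3.0, 1.0
precision = tp / ((tp + fp) or 1)   # 0.75; the guard is a no-op here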
@@ -477,7 +519,7 @@ class LunaTrainingApp(object):
     #             writer.add_histogram(
     #                 name.rsplit('.', 1)[-1] + '/' + name,
     #                 param.data.cpu().numpy(),
-    #                 # metrics_ary[METRICS_PRED_NDX, benHist_mask],
+    #                 # metrics_a[METRICS_PRED_NDX, benHist_mask],
     #                 self.totalTrainingSamples_count,
     #                 # bins=bins,
     #             )

+ 210 - 348
p2ch12/training.py

@@ -1,44 +1,32 @@
 import argparse
 import datetime
 import os
-import socket
 import sys
 
 import numpy as np
-from tensorboardX import SummaryWriter
+
+from torch.utils.tensorboard import SummaryWriter
 
 import torch
 import torch.nn as nn
-import torch.optim
-
-from torch.optim import SGD, Adam
+from torch.optim import SGD
 from torch.utils.data import DataLoader
 
 from util.util import enumerateWithEstimate
-from .dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
+from .dsets import LunaDataset
 from util.logconf import logging
-from util.util import xyz2irc
-from .model import UNetWrapper
+from .model import LunaModel
 
 log = logging.getLogger(__name__)
 # log.setLevel(logging.WARN)
-# log.setLevel(logging.INFO)
-log.setLevel(logging.DEBUG)
-
-# Used for computeClassificationLoss and logMetrics to index into metrics_tensor/metrics_ary
-METRICS_LABEL_NDX = 0
-METRICS_LOSS_NDX = 1
-METRICS_MAL_LOSS_NDX = 2
-METRICS_BEN_LOSS_NDX = 3
-
-METRICS_MTP_NDX = 4
-METRICS_MFN_NDX = 5
-METRICS_MFP_NDX = 6
-METRICS_BTP_NDX = 7
-METRICS_BFN_NDX = 8
-# METRICS_BFP_NDX = 9
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
 
-METRICS_SIZE = 9
+# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
+METRICS_LABEL_NDX=0
+METRICS_PRED_NDX=1
+METRICS_LOSS_NDX=2
+METRICS_SIZE = 3
 
 class LunaTrainingApp(object):
     def __init__(self, sys_argv=None):
@@ -48,7 +36,7 @@ class LunaTrainingApp(object):
         parser = argparse.ArgumentParser()
         parser.add_argument('--batch-size',
             help='Batch size to use for training',
-            default=24,
+            default=32,
             type=int,
         )
         parser.add_argument('--num-workers',
@@ -61,7 +49,11 @@ class LunaTrainingApp(object):
             default=1,
             type=int,
         )
-
+        parser.add_argument('--balanced',
+            help="Balance the training data to half benign, half malignant.",
+            action='store_true',
+            default=False,
+        )
         parser.add_argument('--augmented',
             help="Augment the training data.",
             action='store_true',
@@ -72,16 +64,16 @@ class LunaTrainingApp(object):
             action='store_true',
             default=False,
         )
-        # parser.add_argument('--augment-offset',
-        #     help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
-        #     action='store_true',
-        #     default=False,
-        # )
-        # parser.add_argument('--augment-scale',
-        #     help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
-        #     action='store_true',
-        #     default=False,
-        # )
+        parser.add_argument('--augment-offset',
+            help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-scale',
+            help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
+            action='store_true',
+            default=False,
+        )
         parser.add_argument('--augment-rotate',
             help="Augment the training data by randomly rotating the data around the head-foot axis.",
             action='store_true',
@@ -97,62 +89,56 @@ class LunaTrainingApp(object):
             default='p2ch12',
             help="Data prefix to use for Tensorboard run. Defaults to chapter.",
         )
-
         parser.add_argument('comment',
             help="Comment suffix for Tensorboard run.",
             nargs='?',
-            default='none',
+            default='dlwpt',
         )
 
         self.cli_args = parser.parse_args(sys_argv)
         self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
 
         self.trn_writer = None
-        self.tst_writer = None
+        self.val_writer = None
+        self.totalTrainingSamples_count = 0
+
+        self.augmentation_dict = {}
+        if self.cli_args.augmented or self.cli_args.augment_flip:
+            self.augmentation_dict['flip'] = True
+        if self.cli_args.augmented or self.cli_args.augment_offset:
+            self.augmentation_dict['offset'] = 0.1
+        if self.cli_args.augmented or self.cli_args.augment_scale:
+            self.augmentation_dict['scale'] = 0.2
+        if self.cli_args.augmented or self.cli_args.augment_rotate:
+            self.augmentation_dict['rotate'] = True
+        if self.cli_args.augmented or self.cli_args.augment_noise:
+            self.augmentation_dict['noise'] = 25.0
 
         self.use_cuda = torch.cuda.is_available()
         self.device = torch.device("cuda" if self.use_cuda else "cpu")
 
-        # # TODO: remove this if block before print
-        # # This is due to an odd setup that the author is using to test the code; please ignore for now
-        # if socket.gethostname() == 'c2':
-        #     self.device = torch.device("cuda:1")
-
         self.model = self.initModel()
         self.optimizer = self.initOptimizer()
 
-        self.totalTrainingSamples_count = 0
-
-        augmentation_dict = {}
-        if self.cli_args.augmented or self.cli_args.augment_flip:
-            augmentation_dict['flip'] = True
-        if self.cli_args.augmented or self.cli_args.augment_rotate:
-            augmentation_dict['rotate'] = True
-        if self.cli_args.augmented or self.cli_args.augment_noise:
-            augmentation_dict['noise'] = 25.0
-        self.augmentation_dict = augmentation_dict
-
 
     def initModel(self):
-        # model = UNetWrapper(in_channels=8, n_classes=2, depth=3, wf=6, padding=True, batch_norm=True, up_mode='upconv')
-        model = UNetWrapper(in_channels=7, n_classes=1, depth=4, wf=3, padding=True, batch_norm=True, up_mode='upconv')
-
+        model = LunaModel()
         if self.use_cuda:
+            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
             if torch.cuda.device_count() > 1:
                 model = nn.DataParallel(model)
             model = model.to(self.device)
-
         return model
 
     def initOptimizer(self):
-        return SGD(self.model.parameters(), lr=0.01, momentum=0.99)
+        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
         # return Adam(self.model.parameters())
 
-
     def initTrainDl(self):
-        train_ds = TrainingLuna2dSegmentationDataset(
-            test_stride=10,
-            contextSlices_count=3,
+        train_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=False,
+            ratio_int=int(self.cli_args.balanced),
             augmentation_dict=self.augmentation_dict,
         )
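
The new --balanced flag reaches the dataset above as ratio_int; a nonzero ratio resamples the training data toward an even benign/malignant split. As a hypothetical sketch (not the dataset's actual code), a 1:1 interleave over two per-class sample lists could be indexed like this:

def balanced_getitem(ndx, mal_list, ben_list):
    # Even indexes draw malignant samples, odd indexes benign ones,
    # wrapping around whichever list is shorter.
    if ndx % 2 == 0:
        return mal_list[(ndx // 2) % len(mal_list)]
    return ben_list[(ndx // 2) % len(ben_list)]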
 
@@ -165,69 +151,61 @@ class LunaTrainingApp(object):
 
         return train_dl
 
-    def initTestDl(self):
-        test_ds = Luna2dSegmentationDataset(
-            test_stride=10,
-            contextSlices_count=3,
+    def initValDl(self):
+        val_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=True,
         )
 
-        test_dl = DataLoader(
-            test_ds,
+        val_dl = DataLoader(
+            val_ds,
             batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
             num_workers=self.cli_args.num_workers,
             pin_memory=self.use_cuda,
         )
 
-        return test_dl
+        return val_dl
 
     def initTensorboardWriters(self):
         if self.trn_writer is None:
             log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
 
-            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
-            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_seg_' + self.cli_args.comment)
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
+            self.val_writer = SummaryWriter(log_dir=log_dir + '-val_cls-' + self.cli_args.comment)
 
 
     def main(self):
         log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
 
         train_dl = self.initTrainDl()
-        test_dl = self.initTestDl()
-
-        self.initTensorboardWriters()
-        # self.logModelMetrics(self.model)
-
-        best_score = 0.0
+        val_dl = self.initValDl()
 
         for epoch_ndx in range(1, self.cli_args.epochs + 1):
+
             log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                 epoch_ndx,
                 self.cli_args.epochs,
                 len(train_dl),
-                len(test_dl),
+                len(val_dl),
                 self.cli_args.batch_size,
                 (torch.cuda.device_count() if self.use_cuda else 1),
             ))
 
-            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
-            self.logPerformanceMetrics(epoch_ndx, 'trn', trainingMetrics_tensor)
-            self.logImages(epoch_ndx, train_dl, test_dl)
-            # self.logModelMetrics(self.model)
-
-            testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
-            score = self.logPerformanceMetrics(epoch_ndx, 'tst', testingMetrics_tensor)
-            best_score = max(score, best_score)
+            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
 
-            self.saveModel('seg' if self.cli_args.segmentation else 'cls', epoch_ndx, score == best_score)
+            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
+            self.logMetrics(epoch_ndx, 'val', valMetrics_t)
 
         if hasattr(self, 'trn_writer'):
             self.trn_writer.close()
-            self.tst_writer.close()
+            self.val_writer.close()
+
 
     def doTraining(self, epoch_ndx, train_dl):
-        trainingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
         self.model.train()
-
+        train_dl.dataset.shuffleSamples()
+        trnMetrics_g = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
         batch_iter = enumerateWithEstimate(
             train_dl,
             "E{} Training".format(epoch_ndx),
@@ -236,267 +214,177 @@ class LunaTrainingApp(object):
         for batch_ndx, batch_tup in batch_iter:
             self.optimizer.zero_grad()
 
-            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_devtensor)
-            loss_var.backward()
+            loss_var = self.computeBatchLoss(
+                batch_ndx,
+                batch_tup,
+                train_dl.batch_size,
+                trnMetrics_g
+            )
 
+            loss_var.backward()
             self.optimizer.step()
             del loss_var
 
-        self.totalTrainingSamples_count += trainingMetrics_devtensor.size(1)
+        self.totalTrainingSamples_count += len(train_dl.dataset)
 
-        return trainingMetrics_devtensor.to('cpu')
+        return trnMetrics_g.to('cpu')
 
-    def doTesting(self, epoch_ndx, test_dl):
+
+    def doValidation(self, epoch_ndx, val_dl):
         with torch.no_grad():
-            testingMetrics_devtensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset)).to(self.device)
             self.model.eval()
-
+            valMetrics_g = torch.zeros(METRICS_SIZE, len(val_dl.dataset)).to(self.device)
             batch_iter = enumerateWithEstimate(
-                test_dl,
-                "E{} Testing ".format(epoch_ndx),
-                start_ndx=test_dl.num_workers,
+                val_dl,
+                "E{} Validation ".format(epoch_ndx),
+                start_ndx=val_dl.num_workers,
             )
             for batch_ndx, batch_tup in batch_iter:
-                self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_devtensor)
-
-        return testingMetrics_devtensor.to('cpu')
-
-
-    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_devtensor):
-        input_tensor, label_tensor, label_list, ben_tensor, mal_tensor, _series_list, _start_list = batch_tup
-
-        input_devtensor = input_tensor.to(self.device, non_blocking=True)
-        label_devtensor = label_tensor.to(self.device, non_blocking=True)
-        mal_devtensor = mal_tensor.to(self.device, non_blocking=True)
-        ben_devtensor = ben_tensor.to(self.device, non_blocking=True)
-
-        start_ndx = batch_ndx * batch_size
-        end_ndx = start_ndx + label_tensor.size(0)
-        intersectionSum = lambda a, b: (a * b).view(a.size(0), -1).sum(dim=1)
-
-        prediction_devtensor = self.model(input_devtensor)
-        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
-
-        with torch.no_grad():
-            predictionBool_devtensor = (prediction_devtensor > 0.5).to(torch.float32)
-
-            metrics_devtensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
-            metrics_devtensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_devtensor
-
-            malPred_devtensor = predictionBool_devtensor * (1 - ben_devtensor)
-
-            tp = intersectionSum(    mal_devtensor,     malPred_devtensor)
-            fn = intersectionSum(    mal_devtensor, 1 - malPred_devtensor)
-            fp = intersectionSum(1 - mal_devtensor,     malPred_devtensor)
-            ls = self.diceLoss(mal_devtensor, malPred_devtensor)
-
-            metrics_devtensor[METRICS_MTP_NDX, start_ndx:end_ndx] = tp
-            metrics_devtensor[METRICS_MFN_NDX, start_ndx:end_ndx] = fn
-            metrics_devtensor[METRICS_MFP_NDX, start_ndx:end_ndx] = fp
-            metrics_devtensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = ls
-
-            del malPred_devtensor, tp, fn, fp, ls
-
-            benPred_devtensor = predictionBool_devtensor * (1 - mal_devtensor)
-            tp = intersectionSum(    ben_devtensor,     benPred_devtensor)
-            fn = intersectionSum(    ben_devtensor, 1 - benPred_devtensor)
-            # fp = intersectionSum(1 - ben_devtensor,     benPred_devtensor)
-            ls = self.diceLoss(ben_devtensor, benPred_devtensor)
+                self.computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)
 
-            metrics_devtensor[METRICS_BTP_NDX, start_ndx:end_ndx] = tp
-            metrics_devtensor[METRICS_BFN_NDX, start_ndx:end_ndx] = fn
-            # metrics_devtensor[METRICS_BFP_NDX, start_ndx:end_ndx] = fp
-            metrics_devtensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = ls
+        return valMetrics_g.to('cpu')
 
-            del benPred_devtensor, tp, fn, ls
 
-        return diceLoss_devtensor.mean()
 
-    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01, p=False):
-        sum_dim1 = lambda t: t.view(t.size(0), -1).sum(dim=1)
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
+        input_t, label_t, _series_list, _center_list = batch_tup
 
-        diceLabel_devtensor = sum_dim1(label_devtensor)
-        dicePrediction_devtensor = sum_dim1(prediction_devtensor)
-        diceCorrect_devtensor = sum_dim1(prediction_devtensor * label_devtensor)
+        input_g = input_t.to(self.device, non_blocking=True)
+        label_g = label_t.to(self.device, non_blocking=True)
 
-        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
-        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)
+        logits_g, probability_g = self.model(input_g)
 
-        return diceLoss_devtensor
-
-
-
-    def logImages(self, epoch_ndx, train_dl, test_dl):
-        for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
-            for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
-                ct = getCt(series_uid)
-                noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
-                center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)
-
-                sample_tup = dl.dataset[(series_uid, int(center_irc.index))]
-                # input_tensor = sample_tup[0].unsqueeze(0)
-                # label_tensor = sample_tup[1].unsqueeze(0)
-
-                input_tensor, label_tensor, ben_tensor, mal_tensor = sample_tup[:4]
-                input_tensor += 1000
-                input_tensor /= 2001
-
-                input_devtensor = input_tensor.to(self.device)
-                # label_devtensor = label_tensor.to(self.device)
-
-                prediction_devtensor = self.model(input_devtensor.unsqueeze(0))[0]
-                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()
-                label_ary = label_tensor.numpy()
-                ben_ary = ben_tensor.numpy()
-                mal_ary = mal_tensor.numpy()
-
-                # log.debug([prediction_ary[0].shape, label_ary.shape, mal_ary.shape])
-
-                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
-                image_ary[:,:,:] = (input_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1))) * 0.5
-                image_ary[:,:,0] += prediction_ary[0] * (1 - label_ary[0]) * 0.5
-                image_ary[:,:,1] += prediction_ary[0] * mal_ary * 0.5
-                image_ary[:,:,2] += prediction_ary[0] * ben_ary * 0.5
-                # image_ary[:,:,2] += prediction_ary[0,1] * 0.25
-                # image_ary[:,:,2] += prediction_ary[0,2] * 0.5
-
-                # log.debug([image_ary.__array_interface__['typestr']])
-
-                # image_ary = (image_ary * 255).astype(np.uint8)
-
-                # log.debug([image_ary.__array_interface__['typestr']])
-
-                writer = getattr(self, mode_str + '_writer')
-                try:
-                    image_ary[image_ary < 0] = 0
-                    image_ary[image_ary > 1] = 1
-                    writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
-                except:
-                    log.debug([image_ary.shape, image_ary.dtype])
-                    raise
-
-                if epoch_ndx == 1:
-
-                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
-                    image_ary[:,:,:] = (input_tensor[dl.dataset.contextSlices_count].numpy().reshape((512,512,1))) * 0.5
-                    image_ary[:,:,1] += mal_ary * 0.5
-                    image_ary[:,:,2] += ben_ary * 0.5
-                    # image_ary[:,:,2] += label_ary[0,1] * 0.25
-                    # image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(np.bool) | label_ary[0,1].astype(np.bool))) * 0.25
-
-                    # log.debug([image_ary.__array_interface__['typestr']])
-
-                    # image_ary = (image_ary * 255).astype(np.uint8)
+        loss_func = nn.CrossEntropyLoss(reduction='none')
+        loss_g = loss_func(
+            logits_g,
+            label_g[:,1],
+        )
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_t.size(0)
 
-                    # log.debug([image_ary.__array_interface__['typestr']])
+        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_g[:,1]
+        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_g[:,1]
+        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g
 
-                    writer = getattr(self, mode_str + '_writer')
-                    image_ary[image_ary < 0] = 0
-                    image_ary[image_ary > 1] = 1
-                    writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+        return loss_g.mean()
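
computeBatchLoss keeps reduction='none' so CrossEntropyLoss returns one loss value per sample; those land in the start_ndx:end_ndx slice of the metrics tensor, and only the final .mean() produces the scalar that backward() needs. In isolation (shapes illustrative):

import torch
import torch.nn as nn

loss_func = nn.CrossEntropyLoss(reduction='none')
logits = torch.randn(4, 2)                   # batch of 4, 2 classes
labels = torch.tensor([0, 1, 1, 0])          # class indexes
loss_per_sample = loss_func(logits, labels)  # shape (4,), one loss each
loss_for_backprop = loss_per_sample.mean()   # scalar used for .backward()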
 
 
-    def logPerformanceMetrics(self,
-                              epoch_ndx,
-                              mode_str,
-                              metrics_tensor,
-                              ):
+    def logMetrics(
+            self,
+            epoch_ndx,
+            mode_str,
+            metrics_t,
+    ):
+        self.initTensorboardWriters()
         log.info("E{} {}".format(
             epoch_ndx,
             type(self).__name__,
         ))
 
+        benLabel_mask = metrics_t[METRICS_LABEL_NDX] <= 0.5
+        benPred_mask = metrics_t[METRICS_PRED_NDX] <= 0.5
 
-        # for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
-        metrics_ary = metrics_tensor.cpu().detach().numpy()
-        sum_ary = metrics_ary.sum(axis=1)
-        assert np.isfinite(metrics_ary).all()
+        malLabel_mask = ~benLabel_mask
+        malPred_mask = ~benPred_mask
 
-        malLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 1) | (metrics_ary[METRICS_LABEL_NDX] == 3)
+        ben_count = int(benLabel_mask.sum())
+        mal_count = int(malLabel_mask.sum())
 
-        benLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 2) | (metrics_ary[METRICS_LABEL_NDX] == 3)
-        # malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
+        trueNeg_count = ben_correct = int((benLabel_mask & benPred_mask).sum())
+        truePos_count = mal_correct = int((malLabel_mask & malPred_mask).sum())
 
-        # malLabel_mask = ~benLabel_mask
-        # malPred_mask = ~benPred_mask
+        falsePos_count = ben_count - ben_correct
+        falseNeg_count = mal_count - mal_correct
 
-        benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
-        malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_t[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/ben'] = metrics_t[METRICS_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = metrics_t[METRICS_LOSS_NDX, malLabel_mask].mean()
 
-        trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
-        truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]
-#
-#             falsePos_count = benLabel_count - benCorrect_count
-#             falseNeg_count = malLabel_count - malCorrect_count
+        metrics_dict['correct/all'] = (mal_correct + ben_correct) / metrics_t.shape[1] * 100
+        metrics_dict['correct/ben'] = (ben_correct) / ben_count * 100
+        metrics_dict['correct/mal'] = (mal_correct) / mal_count * 100
 
+        precision = metrics_dict['pr/precision'] = \
+            truePos_count / np.float64(truePos_count + falsePos_count)
+        recall    = metrics_dict['pr/recall'] = \
+            truePos_count / np.float64(truePos_count + falseNeg_count)
 
-        metrics_dict = {}
-        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
-        # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
-        # metrics_dict['loss/mal'] = metrics_ary[METRICS_MALLOSS_NDX].mean()
-        # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
-        metrics_dict['loss/mal'] = np.nan_to_num(metrics_ary[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
-        metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX, benLabel_mask].mean()
-        # metrics_dict['loss/flg'] = metrics_ary[METRICS_FLG_LOSS_NDX].mean()
-
-        # metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
-        # metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100
-
-        metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
-        metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100
-
-        precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
-        recall    = metrics_dict['pr/recall']    = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
-
-        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)
-
-        log.info(("E{} {:8} "
-                 + "{loss/all:.4f} loss, "
-                 # + "{loss/flg:.4f} flagged loss, "
-                 # + "{flagged/all:-5.1f}% pixels flagged, "
-                 # + "{flagged/slices:-5.1f}% slices flagged, "
+        metrics_dict['pr/f1_score'] = \
+            2 * (precision * recall) / (precision + recall)
+
+        log.info(
+            ("E{} {:8} {loss/all:.4f} loss, "
+                 + "{correct/all:-5.1f}% correct, "
                  + "{pr/precision:.4f} precision, "
                  + "{pr/recall:.4f} recall, "
                  + "{pr/f1_score:.4f} f1 score"
-                  ).format(
-            epoch_ndx,
-            mode_str,
-            **metrics_dict,
-        ))
-        log.info(("E{} {:8} "
-                 + "{loss/mal:.4f} loss, "
-                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
-        ).format(
-            epoch_ndx,
-            mode_str + '_mal',
-            malCorrect_count=malCorrect_count,
-            malLabel_count=malLabel_count,
-            **metrics_dict,
-        ))
-        log.info(("E{} {:8} "
-                 + "{loss/ben:.4f} loss, "
-                 + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
-        ).format(
-            epoch_ndx,
-            mode_str + '_ben',
-            benCorrect_count=benCorrect_count,
-            benLabel_count=benLabel_count,
-            **metrics_dict,
-        ))
-
+            ).format(
+                epoch_ndx,
+                mode_str,
+                **metrics_dict,
+            )
+        )
+        log.info(
+            ("E{} {:8} {loss/ben:.4f} loss, "
+                 + "{correct/ben:-5.1f}% correct ({ben_correct:} of {ben_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_ben',
+                ben_correct=ben_correct,
+                ben_count=ben_count,
+                **metrics_dict,
+            )
+        )
+        log.info(
+            ("E{} {:8} {loss/mal:.4f} loss, "
+                 + "{correct/mal:-5.1f}% correct ({mal_correct:} of {mal_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_mal',
+                mal_correct=mal_correct,
+                mal_count=mal_count,
+                **metrics_dict,
+            )
+        )
         writer = getattr(self, mode_str + '_writer')
 
-        prefix_str = 'seg_'
-
         for key, value in metrics_dict.items():
-            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)
+            writer.add_scalar(key, value, self.totalTrainingSamples_count)
 
-        score = 1 \
-            + metrics_dict['pr/f1_score'] \
-            - metrics_dict['loss/mal'] * 0.01 \
-            - metrics_dict['loss/all'] * 0.0001
+        writer.add_pr_curve(
+            'pr',
+            metrics_t[METRICS_LABEL_NDX],
+            metrics_t[METRICS_PRED_NDX],
+            self.totalTrainingSamples_count,
+        )
+
+        bins = [x/50.0 for x in range(51)]
 
-        return score
+        benHist_mask = benLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
+        malHist_mask = malLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)
+
+        if benHist_mask.any():
+            writer.add_histogram(
+                'is_ben',
+                metrics_t[METRICS_PRED_NDX, benHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+        if malHist_mask.any():
+            writer.add_histogram(
+                'is_mal',
+                metrics_t[METRICS_PRED_NDX, malHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
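
The histogram masks above deliberately drop samples the model already classifies with near-total confidence (benign predictions at or below 0.01, malignant predictions at or above 0.99), so the TensorBoard histograms emphasize the uncertain middle; the .any() guards skip add_histogram when a mask is empty. A small check of the mask logic:

import torch
pred = torch.tensor([0.001, 0.2, 0.98, 0.999])
label = torch.tensor([0., 0., 1., 1.])
ben_mask = (label <= 0.5) & (pred > 0.01)   # keeps only index 1 (0.2)
mal_mask = (label > 0.5) & (pred < 0.99)    # keeps only index 2 (0.98)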
+
+        # score = 1 \
+        #     + metrics_dict['pr/f1_score'] \
+        #     - metrics_dict['loss/mal'] * 0.01 \
+        #     - metrics_dict['loss/all'] * 0.0001
+        #
+        # return score
 
     # def logModelMetrics(self, model):
     #     writer = getattr(self, 'trn_writer')
@@ -511,44 +399,18 @@ class LunaTrainingApp(object):
     #
     #             # bins = [x/50*max_extent for x in range(-50, 51)]
     #
-    #             writer.add_histogram(
-    #                 name.rsplit('.', 1)[-1] + '/' + name,
-    #                 param.data.cpu().numpy(),
-    #                 # metrics_ary[METRICS_PRED_NDX, benHist_mask],
-    #                 self.totalTrainingSamples_count,
-    #                 # bins=bins,
-    #             )
-    #
-    #             # print name, param.data
-
-    def saveModel(self, type_str, epoch_ndx, isBest=False):
-        file_path = os.path.join('data-unversioned', 'models', self.cli_args.tb_prefix, '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, self.totalTrainingSamples_count))
-
-        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
-
-        model = self.model
-        if hasattr(model, 'module'):
-            model = model.module
-
-        state = {
-            'model_state': model.state_dict(),
-            'model_name': type(model).__name__,
-            'optimizer_state' : self.optimizer.state_dict(),
-            'optimizer_name': type(self.optimizer).__name__,
-            'epoch': epoch_ndx,
-            'totalTrainingSamples_count': self.totalTrainingSamples_count,
-            # 'resumed_from': self.cli_args.resume,
-        }
-        torch.save(state, file_path)
-
-        log.debug("Saved model params to {}".format(file_path))
-
-        if isBest:
-            file_path = os.path.join('data-unversioned', 'models', self.cli_args.tb_prefix, '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, 'best'))
-            torch.save(state, file_path)
-
-            log.debug("Saved model params to {}".format(file_path))
+    #             try:
+    #                 writer.add_histogram(
+    #                     name.rsplit('.', 1)[-1] + '/' + name,
+    #                     param.data.cpu().numpy(),
+    #                     # metrics_a[METRICS_PRED_NDX, benHist_mask],
+    #                     self.totalTrainingSamples_count,
+    #                     # bins=bins,
+    #                 )
+    #             except Exception as e:
+    #                 log.error([min_data, max_data])
+    #                 raise
 
 
 if __name__ == '__main__':
-    sys.exit(LunaTrainingApp().main() or 0)
+    LunaTrainingApp().main()

+ 44 - 31
p2ch12/vis.py

@@ -6,14 +6,14 @@ import matplotlib.pyplot as plt
 
 from p2ch12.dsets import Ct, LunaDataset
 
-clim=(0.0, 1.3)
+clim=(-1000.0, 300)
 
 def findMalignantSamples(start_ndx=0, limit=10):
-    ds = LunaDataset()
+    ds = LunaDataset(sortby_str='malignancy_size')
 
     malignantSample_list = []
-    for sample_tup in ds.sample_list:
-        if sample_tup[2]:
+    for sample_tup in ds.noduleInfo_list:
+        if sample_tup.isMalignant_bool:
             print(len(malignantSample_list), sample_tup)
             malignantSample_list.append(sample_tup)
 
@@ -24,7 +24,7 @@ def findMalignantSamples(start_ndx=0, limit=10):
 
 def showNodule(series_uid, batch_ndx=None, **kwargs):
     ds = LunaDataset(series_uid=series_uid, **kwargs)
-    malignant_list = [i for i, x in enumerate(ds.sample_list) if x[2]]
+    malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x.isMalignant_bool]
 
     if batch_ndx is None:
         if malignant_list:
@@ -34,53 +34,66 @@ def showNodule(series_uid, batch_ndx=None, **kwargs):
             batch_ndx = 0
 
     ct = Ct(series_uid)
-    # ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
-    malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor = ds[batch_ndx]
-    ct_ary = nodule_tensor[1].numpy()
+    ct_t, malignant_t, series_uid, center_irc = ds[batch_ndx]
+    ct_a = ct_t[0].numpy()
 
-
-    fig = plt.figure(figsize=(15, 25))
+    fig = plt.figure(figsize=(30, 50))
 
     group_list = [
-        #[0,1,2],
-        [9,11,13],
+        [9, 11, 13],
         [15, 16, 17],
-        [19,21,23],
-        #[12,13,14],
-        #[15]
+        [19, 21, 23],
     ]
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
-    subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
+    subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[int(center_irc.index)], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
-    subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
+    subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[:,int(center_irc.row)], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
-    subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+    subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
-    subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct_ary[ct_ary.shape[0]//2], clim=clim, cmap='gray')
+    subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[ct_a.shape[0]//2], clim=clim, cmap='gray')
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
-    subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct_ary[:,ct_ary.shape[1]//2], clim=clim, cmap='gray')
+    subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[:,ct_a.shape[1]//2], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
-    subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct_ary[:,:,ct_ary.shape[2]//2], clim=clim, cmap='gray')
+    subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[:,:,ct_a.shape[2]//2], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
 
     for row, index_list in enumerate(group_list):
         for col, index in enumerate(index_list):
             subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
-            subplot.set_title('slice {}'.format(index))
-            plt.imshow(ct_ary[index*2], clim=clim, cmap='gray')
+            subplot.set_title('slice {}'.format(index), fontsize=30)
+            for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+                label.set_fontsize(20)
+            plt.imshow(ct_a[index], clim=clim, cmap='gray')
+
 
+    print(series_uid, batch_ndx, bool(malignant_t[0]), malignant_list)
 
-    print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list, ct.vxSize_xyz)
 
-    return ct_ary

File diff suppressed because it is too large
+ 4 - 40
p2ch12_explore_data.ipynb


File diff suppressed because it is too large
+ 14 - 23
p2ch12_explore_diagnose.ipynb


+ 0 - 0
p2ch13/__init__.py


+ 372 - 0
p2ch13/diagnose.py

@@ -0,0 +1,372 @@
+import argparse
+import glob
+import os
+import sys
+
+import numpy as np
+import scipy.ndimage.measurements as measure
+import scipy.ndimage.morphology as morph
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset, Luna2dSegmentationDataset, getCt, getNoduleInfoList, NoduleInfoTuple
+from .model_seg import UNetWrapper
+from .model_cls import LunaModel, AlternateLunaModel
+
+from util.logconf import logging
+from util.util import xyz2irc, irc2xyz
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+
+class LunaDiagnoseApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            log.debug(sys.argv)
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=4,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+
+        parser.add_argument('--series-uid',
+            help='Limit inference to this Series UID only.',
+            default=None,
+            type=str,
+        )
+
+        parser.add_argument('--include-train',
+            help="Include data that was in the training set. (default: validation data only)",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--segmentation-path',
+            help="Path to the saved segmentation model",
+            nargs='?',
+            default=None,
+        )
+
+        parser.add_argument('--classification-path',
+            help="Path to the saved classification model",
+            nargs='?',
+            default=None,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch13',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        if not self.cli_args.segmentation_path:
+            self.cli_args.segmentation_path = self.initModelPath('seg')
+
+        if not self.cli_args.classification_path:
+            self.cli_args.classification_path = self.initModelPath('cls')
+
+        self.seg_model, self.cls_model = self.initModels()
+
+    def initModelPath(self, type_str):
+        local_path = os.path.join(
+            'data-unversioned',
+            'part2',
+            'models',
+            self.cli_args.tb_prefix,
+            type_str + '_{}_{}.{}.state'.format('*', '*', 'best'),
+        )
+
+        file_list = glob.glob(local_path)
+        if not file_list:
+            pretrained_path = os.path.join(
+                'data',
+                'part2',
+                'models',
+                type_str + '_{}_{}.{}.state'.format('*', '*', '*'),
+            )
+            file_list = glob.glob(pretrained_path)
+        else:
+            pretrained_path = None
+
+        file_list.sort()
+
+        try:
+            return file_list[-1]
+        except IndexError:
+            log.debug([local_path, pretrained_path, file_list])
+            raise
+
+    def initModels(self):
+        log.debug(self.cli_args.segmentation_path)
+        seg_dict = torch.load(self.cli_args.segmentation_path)
+
+        seg_model = UNetWrapper(in_channels=8, n_classes=1, depth=4, wf=3, padding=True, batch_norm=True, up_mode='upconv')
+        seg_model.load_state_dict(seg_dict['model_state'])
+        seg_model.eval()
+
+        log.debug(self.cli_args.classification_path)
+        cls_dict = torch.load(self.cli_args.classification_path)
+
+        cls_model = LunaModel()
+        # cls_model = AlternateLunaModel()
+        cls_model.load_state_dict(cls_dict['model_state'])
+        cls_model.eval()
+
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                seg_model = nn.DataParallel(seg_model)
+                cls_model = nn.DataParallel(cls_model)
+
+            seg_model = seg_model.to(self.device)
+            cls_model = cls_model.to(self.device)
+
+        return seg_model, cls_model
+
+
+    def initSegmentationDl(self, series_uid):
+        seg_ds = Luna2dSegmentationDataset(
+                contextSlices_count=3,
+                series_uid=series_uid,
+                fullCt_bool=True,
+            )
+        seg_dl = DataLoader(
+            seg_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return seg_dl
+
+    def initClassificationDl(self, noduleInfo_list):
+        cls_ds = LunaDataset(
+                sortby_str='series_uid',
+                noduleInfo_list=noduleInfo_list,
+            )
+        cls_dl = DataLoader(
+            cls_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return cls_dl
+
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        val_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=True,
+        )
+        val_set = set(
+            noduleInfo_tup.series_uid
+            for noduleInfo_tup in val_ds.noduleInfo_list
+        )
+        malignant_set = set(
+            noduleInfo_tup.series_uid
+            for noduleInfo_tup in getNoduleInfoList()
+            if noduleInfo_tup.isMalignant_bool
+        )
+
+        if self.cli_args.series_uid:
+            series_set = set(self.cli_args.series_uid.split(','))
+        else:
+            series_set = set(
+                noduleInfo_tup.series_uid
+                for noduleInfo_tup in getNoduleInfoList()
+            )
+
+        train_list = sorted(series_set - val_set) if self.cli_args.include_train else []
+        val_list = sorted(series_set & val_set)
+
+
+        noduleInfo_list = []
+        series_iter = enumerateWithEstimate(
+            val_list + train_list,
+            "Series",
+        )
+        for _series_ndx, series_uid in series_iter:
+            ct, output_a, _mask_a, clean_a = self.segmentCt(series_uid)
+
+            noduleInfo_list += self.clusterSegmentationOutput(
+                series_uid,
+                ct,
+                clean_a,
+            )
+
+            # if _series_ndx > 10:
+            #     break
+
+
+        cls_dl = self.initClassificationDl(noduleInfo_list)
+
+        series2diagnosis_dict = {}
+        batch_iter = enumerateWithEstimate(
+            cls_dl,
+            "Cls all",
+            start_ndx=cls_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            input_t, _, series_list, center_list = batch_tup
+
+            input_g = input_t.to(self.device)
+            with torch.no_grad():
+                _logits_g, probability_g = self.cls_model(input_g)
+
+            classifications_list = zip(
+                series_list,
+                center_list,
+                probability_g[:,1].to('cpu'),
+            )
+
+            for cls_tup in classifications_list:
+                series_uid, center_irc, probability_t = cls_tup
+                probability_float = probability_t.item()
+
+                this_tup = (probability_float, tuple(center_irc))
+                current_tup = series2diagnosis_dict.get(series_uid, this_tup)
+                try:
+                    assert np.all(np.isfinite(tuple(center_irc)))
+                    if this_tup > current_tup:
+                        log.debug([series_uid, this_tup])
+                    series2diagnosis_dict[series_uid] = max(this_tup, current_tup)
+                except Exception:
+                    log.debug([(type(x), x) for x in this_tup] + [(type(x), x) for x in current_tup])
+                    raise
+
+        log.info('Training set:')
+        self.logResults('Training', train_list, series2diagnosis_dict, malignant_set)
+
+        log.info('Validation set:')
+        self.logResults('Validation', val_list, series2diagnosis_dict, malignant_set)
+
+    def segmentCt(self, series_uid):
+        with torch.no_grad():
+            ct = getCt(series_uid)
+
+            output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
+
+            seg_dl = self.initSegmentationDl(series_uid)
+            for batch_tup in seg_dl:
+                input_t = batch_tup[0]
+                ndx_list = batch_tup[6]
+
+                input_g = input_t.to(self.device)
+                prediction_g = self.seg_model(input_g)
+
+                for i, sample_ndx in enumerate(ndx_list):
+                    output_a[sample_ndx] = prediction_g[i].cpu().numpy()
+
+            mask_a = output_a > 0.5
+            clean_a = morph.binary_erosion(mask_a, iterations=1)
+            clean_a = morph.binary_dilation(clean_a, iterations=2)
+
+        return ct, output_a, mask_a, clean_a
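
The threshold/erode/dilate sequence in segmentCt is a standard mask-denoising pass: one erosion removes blobs only a voxel or two across (stray false positives), and two dilations grow the survivors back slightly past their original extent. The same scipy calls on a toy 2D mask:

import numpy as np
import scipy.ndimage.morphology as morph

mask_a = np.zeros((8, 8), dtype=bool)
mask_a[2:6, 2:6] = True   # a genuine 4x4 blob
mask_a[0, 7] = True       # a lone false-positive pixel

clean_a = morph.binary_erosion(mask_a, iterations=1)    # lone pixel dies
clean_a = morph.binary_dilation(clean_a, iterations=2)  # blob regrows
assert not clean_a[0, 7] and clean_a[2:6, 2:6].all()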
+
+    def clusterSegmentationOutput(self, series_uid,  ct, clean_a):
+        noduleLabel_a, nodule_count = measure.label(clean_a)
+        centerIrc_list = measure.center_of_mass(
+            ct.hu_a + 1001,
+            labels=noduleLabel_a,
+            index=list(range(1, nodule_count+1)),
+        )
+
+        # n = 1298
+        # log.debug([
+        #     (noduleLabel_a == n).sum(),
+        #     np.where(noduleLabel_a == n),
+        #
+        #     ct.hu_a[noduleLabel_a == n].sum(),
+        #     (ct.hu_a + 1000)[noduleLabel_a == n].sum(),
+        # ])
+
+        # if nodule_count == 1:
+        #     centerIrc_list = [centerIrc_list]
+
+        noduleInfo_list = []
+        for i, center_irc in enumerate(centerIrc_list):
+            center_xyz = irc2xyz(
+                center_irc,
+                ct.origin_xyz,
+                ct.vxSize_xyz,
+                ct.direction_tup,
+            )
+            assert np.all(np.isfinite(center_irc)), repr(['irc', center_irc, i, nodule_count])
+            assert np.all(np.isfinite(center_xyz)), repr(['xyz', center_xyz])
+            noduleInfo_tup = \
+                NoduleInfoTuple(False, 0.0, series_uid, center_xyz)
+            noduleInfo_list.append(noduleInfo_tup)
+
+        return noduleInfo_list
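
clusterSegmentationOutput converts the cleaned voxel mask back into per-nodule candidates: measure.label gives every connected blob its own integer id, and center_of_mass (weighted by HU shifted up by 1001 so all weights are positive) returns one center per id. The same pattern on a toy array:

import numpy as np
import scipy.ndimage.measurements as measure

mask_a = np.zeros((6, 6), dtype=bool)
mask_a[1:3, 1:3] = True   # blob 1
mask_a[4:6, 4:6] = True   # blob 2

label_a, blob_count = measure.label(mask_a)
centers = measure.center_of_mass(
    mask_a.astype(np.float32),   # uniform weights in this toy case
    labels=label_a,
    index=list(range(1, blob_count + 1)),
)
# centers == [(1.5, 1.5), (4.5, 4.5)]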
+
+    def logResults(self, mode_str, filtered_list, series2diagnosis_dict, malignant_set):
+        count_dict = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
+        for series_uid in filtered_list:
+            probability_float, center_irc = series2diagnosis_dict.get(series_uid, (0.0, None))
+            if center_irc is not None:
+                center_irc = tuple(int(x.item()) for x in center_irc)
+            malignant_bool = series_uid in malignant_set
+            prediction_bool = probability_float > 0.5
+            correct_bool = malignant_bool == prediction_bool
+
+            if malignant_bool and prediction_bool:
+                count_dict['tp'] += 1
+            if not malignant_bool and not prediction_bool:
+                count_dict['tn'] += 1
+            if not malignant_bool and prediction_bool:
+                count_dict['fp'] += 1
+            if malignant_bool and not prediction_bool:
+                count_dict['fn'] += 1
+
+
+            log.info("{} {} Mal:{!r:5} Pred:{!r:5} Correct?:{!r:5} Value:{:.4f} {}".format(
+                mode_str,
+                series_uid,
+                malignant_bool,
+                prediction_bool,
+                correct_bool,
+                probablity_float,
+                center_irc,
+            ))
+
+        total_count = sum(count_dict.values())
+        percent_dict = {k: v / (total_count or 1) * 100 for k, v in count_dict.items()}
+
+        precision = percent_dict['p'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fp']) or 1)
+        recall    = percent_dict['r'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fn']) or 1)
+        percent_dict['f1'] = 2 * (precision * recall) / ((precision + recall) or 1)
+
+        log.info(mode_str + " tp:{tp:.1f}%, tn:{tn:.1f}%, fp:{fp:.1f}%, fn:{fn:.1f}%".format(
+            **percent_dict,
+        ))
+        log.info(mode_str + " precision:{p:.3f}, recall:{r:.3f}, F1:{f1:.3f}".format(
+            **percent_dict,
+        ))
+
+
+
+if __name__ == '__main__':
+    sys.exit(LunaDiagnoseApp().main() or 0)

+ 566 - 0
p2ch13/dsets.py

@@ -0,0 +1,566 @@
+import copy
+import csv
+import functools
+import glob
+import math
+import os
+import random
+
+from collections import namedtuple
+
+import SimpleITK as sitk
+import numpy as np
+import scipy.ndimage.morphology as morph
+
+import torch
+import torch.cuda
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+
+from util.disk import getCache
+from util.util import XyzTuple, xyz2irc
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+raw_cache = getCache('part2ch13_raw')
+
+NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
+MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_nodule_mask, nodule_mask, lung_mask, ben_mask, mal_mask')
+
+@functools.lru_cache(1)
+def getNoduleInfoList(requireDataOnDisk_bool=True):
+    # We construct a set of all series_uids that are present on disk so
+    # that we can use the data even if we haven't downloaded every
+    # subset yet.
+    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
+    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+
+    diameter_dict = {}
+    with open('data/part2/luna/annotations.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
+            annotationDiameter_mm = float(row[4])
+
+            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
+
+    noduleInfo_list = []
+    with open('data/part2/luna/candidates.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+
+            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+                continue
+
+            isMalignant_bool = bool(int(row[4]))
+            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
+
+            candidateDiameter_mm = 0.0
+            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
+                for i in range(3):
+                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
+                    if delta_mm > annotationDiameter_mm / 4:
+                        break
+                else:
+                    candidateDiameter_mm = annotationDiameter_mm
+                    break
+
+            noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+
+    noduleInfo_list.sort(reverse=True)
+    return noduleInfo_list
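
The nested loop in getNoduleInfoList fuzzy-matches candidates to annotations: a candidate inherits an annotation's diameter only when it lies within diameter/4 of that annotation's center on all three axes, and the inner for/else fires its else clause only when no axis triggered the break. Reduced to a skeleton:

def match_diameter(candidate_xyz, annotations):
    # annotations: list of (center_xyz, diameter_mm) tuples
    for center_xyz, diameter_mm in annotations:
        for i in range(3):
            if abs(candidate_xyz[i] - center_xyz[i]) > diameter_mm / 4:
                break           # too far on this axis; try the next annotation
        else:
            return diameter_mm  # close on all three axes: a match
    return 0.0                  # no annotation matched

assert match_diameter((0, 0, 0), [((1, 1, 1), 10.0)]) == 10.0
assert match_diameter((9, 0, 0), [((1, 1, 1), 10.0)]) == 0.0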
+
+class Ct(object):
+    def __init__(self, series_uid, buildMasks_bool=True):
+        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
+
+        ct_mhd = sitk.ReadImage(mhd_path)
+        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+
+        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
+        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
+        # This gets rid of negative density stuff used to indicate out-of-FOV
+        ct_a[ct_a < -1000] = -1000
+
+        # This nukes any weird hotspots and clamps bone down
+        ct_a[ct_a > 1000] = 1000
+
+        self.series_uid = series_uid
+        self.hu_a = ct_a
+
+        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
+        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
+        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
+
+        noduleInfo_list = getNoduleInfoList()
+        self.benignInfo_list = [ni_tup
+                                for ni_tup in noduleInfo_list
+                                    if not ni_tup.isMalignant_bool
+                                        and ni_tup.series_uid == self.series_uid]
+        self.benign_mask = self.buildAnnotationMask(self.benignInfo_list)[0]
+        self.benign_indexes = sorted(set(self.benign_mask.nonzero()[0]))
+
+        self.malignantInfo_list = [ni_tup
+                                   for ni_tup in noduleInfo_list
+                                        if ni_tup.isMalignant_bool
+                                            and ni_tup.series_uid == self.series_uid]
+        self.malignant_mask = self.buildAnnotationMask(self.malignantInfo_list)[0]
+        self.malignant_indexes = sorted(set(self.malignant_mask.nonzero()[0]))
+
+    def buildAnnotationMask(self, noduleInfo_list, threshold_hu = -500):
+        boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool)
+
+        for noduleInfo_tup in noduleInfo_list:
+            center_irc = xyz2irc(
+                noduleInfo_tup.center_xyz,
+                self.origin_xyz,
+                self.vxSize_xyz,
+                self.direction_tup,
+            )
+            ci = int(center_irc.index)
+            cr = int(center_irc.row)
+            cc = int(center_irc.col)
+
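+            # Grow an axis-aligned bounding box out from the center voxel,
+            # expanding each radius while the tissue on both sides stays above
+            # threshold_hu (backing off one step on hitting the volume edge).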
+            index_radius = 2
+            try:
+                while self.hu_a[ci + index_radius, cr, cc] > threshold_hu and \
+                        self.hu_a[ci - index_radius, cr, cc] > threshold_hu:
+                    index_radius += 1
+            except IndexError:
+                index_radius -= 1
+
+            row_radius = 2
+            try:
+                while self.hu_a[ci, cr + row_radius, cc] > threshold_hu and \
+                        self.hu_a[ci, cr - row_radius, cc] > threshold_hu:
+                    row_radius += 1
+            except IndexError:
+                row_radius -= 1
+
+            col_radius = 2
+            try:
+                while self.hu_a[ci, cr, cc + col_radius] > threshold_hu and \
+                        self.hu_a[ci, cr, cc - col_radius] > threshold_hu:
+                    col_radius += 1
+            except IndexError:
+                col_radius -= 1
+
+            # assert index_radius > 0, repr([noduleInfo_tup.center_xyz, center_irc, self.hu_a[ci, cr, cc]])
+            # assert row_radius > 0
+            # assert col_radius > 0
+
+
+            slice_tup = (
+                slice(ci - index_radius, ci + index_radius + 1),
+                slice(cr - row_radius, cr + row_radius + 1),
+                slice(cc - col_radius, cc + col_radius + 1),
+            )
+            boundingBox_a[slice_tup] = True
+
+        thresholded_a = boundingBox_a & (self.hu_a > threshold_hu)
+        mask_a = morph.binary_dilation(thresholded_a, iterations=2)
+
+        return mask_a, thresholded_a, boundingBox_a
+
+    def build2dLungMask(self, mask_ndx):
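+        # Single-slice morphology pipeline: threshold dense tissue, clean it with
+        # closing/opening, fill the body outline, take the air pockets inside the
+        # body as lung, then re-threshold inside the lung for candidate nodules.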
+        raw_dense_mask = self.hu_a[mask_ndx] > -300
+        dense_mask = morph.binary_closing(raw_dense_mask, iterations=2)
+        dense_mask = morph.binary_opening(dense_mask, iterations=2)
+
+        body_mask = morph.binary_fill_holes(dense_mask)
+        air_mask = morph.binary_fill_holes(body_mask & ~dense_mask)
+        air_mask = morph.binary_erosion(air_mask, iterations=1)
+
+        lung_mask = morph.binary_dilation(air_mask, iterations=5)
+
+        raw_nodule_mask = self.hu_a[mask_ndx] > -600
+        raw_nodule_mask &= air_mask
+        nodule_mask = morph.binary_opening(raw_nodule_mask, iterations=1)
+
+        ben_mask = morph.binary_dilation(nodule_mask, iterations=1)
+        ben_mask &= ~self.malignant_mask[mask_ndx]
+
+        mal_mask = self.malignant_mask[mask_ndx]
+
+        return MaskTuple(
+            raw_dense_mask,
+            dense_mask,
+            body_mask,
+            air_mask,
+            raw_nodule_mask,
+            nodule_mask,
+            lung_mask,
+            ben_mask,
+            mal_mask,
+        )
+
+
+    def getRawNodule(self, center_xyz, width_irc):
+        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
+
+        slice_list = []
+        for axis, center_val in enumerate(center_irc):
+            start_ndx = int(round(center_val - width_irc[axis]/2))
+            end_ndx = int(start_ndx + width_irc[axis])
+
+            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+
+            if start_ndx < 0:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
+                start_ndx = 0
+                end_ndx = int(width_irc[axis])
+
+            if end_ndx > self.hu_a.shape[axis]:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
+                end_ndx = self.hu_a.shape[axis]
+                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
+
+            slice_list.append(slice(start_ndx, end_ndx))
+
+        ct_chunk = self.hu_a[tuple(slice_list)]
+
+        return ct_chunk, center_irc
+
+ctCache_depth = 5
+@functools.lru_cache(ctCache_depth, typed=True)
+def getCt(series_uid):
+    return Ct(series_uid)
+
+@raw_cache.memoize(typed=True)
+def getCtRawNodule(series_uid, center_xyz, width_irc):
+    ct = getCt(series_uid)
+    ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+    return ct_chunk, center_irc
+
+@raw_cache.memoize(typed=True)
+def getCtSampleSize(series_uid):
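+    # The per-series sample count for 2D segmentation: the number of slices that
+    # contain any benign-annotation mask voxels.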
+    ct = Ct(series_uid, buildMasks_bool=False)
+    return len(ct.benign_indexes)
+
+def getCtAugmentedNodule(
+        augmentation_dict,
+        series_uid, center_xyz, width_irc,
+        use_cache=True):
+    if use_cache:
+        ct_chunk, center_irc = getCtRawNodule(series_uid, center_xyz, width_irc)
+    else:
+        ct = getCt(series_uid)
+        ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+
+    ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
+
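+    # Build the augmentation as a 4x4 homogeneous affine; F.affine_grid below only
+    # consumes the top three rows (a 3x4 matrix) with translation in the last column.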
+    transform_t = torch.eye(4).to(torch.float64)
+
+    for i in range(3):
+        if 'flip' in augmentation_dict:
+            if random.random() > 0.5:
+                transform_t[i,i] *= -1
+
+        if 'offset' in augmentation_dict:
+            offset_float = augmentation_dict['offset']
+            random_float = (random.random() * 2 - 1)
+            transform_t[i,3] = offset_float * random_float  # last column; row 3 is dropped by transform_t[:3]
+
+        if 'scale' in augmentation_dict:
+            scale_float = augmentation_dict['scale']
+            random_float = (random.random() * 2 - 1)
+            transform_t[i,i] *= 1.0 + scale_float * random_float
+
+    if 'rotate' in augmentation_dict:
+        angle_rad = random.random() * math.pi * 2
+        s = math.sin(angle_rad)
+        c = math.cos(angle_rad)
+
+        rotation_t = torch.tensor([
+            [c, -s, 0, 0],
+            [s, c, 0, 0],
+            [0, 0, 1, 0],
+            [0, 0, 0, 1],
+        ], dtype=torch.float64)
+
+        transform_t @= rotation_t
+
+    affine_t = F.affine_grid(
+            transform_t[:3].unsqueeze(0).to(torch.float32),
+            ct_t.size(),
+        )
+
+    augmented_chunk = F.grid_sample(
+            ct_t,
+            affine_t,
+            padding_mode='border'
+        ).to('cpu')
+
+    if 'noise' in augmentation_dict:
+        noise_t = torch.randn_like(augmented_chunk)
+        noise_t *= augmentation_dict['noise']
+
+        augmented_chunk += noise_t
+
+    return augmented_chunk[0], center_irc
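+# Example call (hypothetical values): getCtAugmentedNodule({'flip': True,
+# 'noise': 25.0}, series_uid, center_xyz, (32, 48, 48)) yields a (1, 32, 48, 48)
+# float32 tensor plus the IRC center of the crop.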
+
+
+class LunaDataset(Dataset):
+    def __init__(self,
+                 val_stride=0,
+                 isValSet_bool=None,
+                 series_uid=None,
+                 sortby_str='random',
+                 ratio_int=0,
+                 augmentation_dict=None,
+                 noduleInfo_list=None,
+            ):
+        self.ratio_int = ratio_int
+        self.augmentation_dict = augmentation_dict
+
+        if noduleInfo_list:
+            self.noduleInfo_list = copy.copy(noduleInfo_list)
+            self.use_cache = False
+        else:
+            self.noduleInfo_list = copy.copy(getNoduleInfoList())
+            self.use_cache = True
+
+        if series_uid:
+            self.series_list = [series_uid]
+        else:
+            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
+        if isValSet_bool:
+            assert val_stride > 0, val_stride
+            self.series_list = self.series_list[::val_stride]
+            assert self.series_list
+        elif val_stride > 0:
+            del self.series_list[::val_stride]
+            assert self.series_list
+
+        series_set = set(self.series_list)
+        self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid in series_set]
+
+        if sortby_str == 'random':
+            random.shuffle(self.noduleInfo_list)
+        elif sortby_str == 'series_uid':
+            self.noduleInfo_list.sort(key=lambda x: (x.series_uid, x.center_xyz))
+        elif sortby_str == 'malignancy_size':
+            pass
+        else:
+            raise Exception("Unknown sort: " + repr(sortby_str))
+
+        self.benign_list = [nt for nt in self.noduleInfo_list if not nt.isMalignant_bool]
+        self.malignant_list = [nt for nt in self.noduleInfo_list if nt.isMalignant_bool]
+
+        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
+            self,
+            len(self.noduleInfo_list),
+            "validation" if isValSet_bool else "training",
+            len(self.benign_list),
+            len(self.malignant_list),
+            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
+        ))
+
+    def shuffleSamples(self):
+        if self.ratio_int:
+            random.shuffle(self.benign_list)
+            random.shuffle(self.malignant_list)
+
+    def __len__(self):
+        if self.ratio_int:
+            # return 20000
+            return 200000
+        else:
+            return len(self.noduleInfo_list)
+
+    def __getitem__(self, ndx):
+        if self.ratio_int:
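+            # With ratio_int N, index 0 and every (N+1)th sample after it are
+            # malignant; the rest are benign, giving an N:1 benign:malignant mix.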
+            malignant_ndx = ndx // (self.ratio_int + 1)
+
+            if ndx % (self.ratio_int + 1):
+                benign_ndx = ndx - 1 - malignant_ndx
+                benign_ndx %= len(self.benign_list)
+                nodule_tup = self.benign_list[benign_ndx]
+            else:
+                malignant_ndx %= len(self.malignant_list)
+                nodule_tup = self.malignant_list[malignant_ndx]
+        else:
+            nodule_tup = self.noduleInfo_list[ndx]
+
+        width_irc = (32, 48, 48)
+
+        if self.augmentation_dict:
+            nodule_t, center_irc = getCtAugmentedNodule(
+                self.augmentation_dict,
+                nodule_tup.series_uid,
+                nodule_tup.center_xyz,
+                width_irc,
+                self.use_cache,
+            )
+        elif self.use_cache:
+            nodule_a, center_irc = getCtRawNodule(
+                nodule_tup.series_uid,
+                nodule_tup.center_xyz,
+                width_irc,
+            )
+            nodule_t = torch.from_numpy(nodule_a).to(torch.float32)
+            nodule_t = nodule_t.unsqueeze(0)
+        else:
+            ct = getCt(nodule_tup.series_uid)
+            nodule_a, center_irc = ct.getRawNodule(
+                nodule_tup.center_xyz,
+                width_irc,
+            )
+            nodule_t = torch.from_numpy(nodule_a).to(torch.float32)
+            nodule_t = nodule_t.unsqueeze(0)
+
+        malignant_t = torch.tensor([
+                not nodule_tup.isMalignant_bool,
+                nodule_tup.isMalignant_bool
+            ],
+            dtype=torch.long,
+        )
+
+        # log.debug([type(center_irc), center_irc])
+
+        return nodule_t, malignant_t, nodule_tup.series_uid, torch.tensor(center_irc)
+
+
+class PrepcacheLunaDataset(LunaDataset):
+    def __getitem__(self, ndx):
+        nodule_t, malignant_t, series_uid, center_t = super().__getitem__(ndx)
+        getCtSampleSize(series_uid)
+        return nodule_t, malignant_t, series_uid, center_t
+
+
+class Luna2dSegmentationDataset(Dataset):
+    def __init__(self,
+                 val_stride=0,
+                 isValSet_bool=None,
+                 series_uid=None,
+                 contextSlices_count=2,
+                 augmentation_dict=None,
+                 fullCt_bool=False,
+            ):
+        self.contextSlices_count = contextSlices_count
+        self.augmentation_dict = augmentation_dict
+
+        if series_uid:
+            self.series_list = [series_uid]
+        else:
+            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
+        if isValSet_bool:
+            assert val_stride > 0, val_stride
+            self.series_list = self.series_list[::val_stride]
+            assert self.series_list
+        elif val_stride > 0:
+            del self.series_list[::val_stride]
+            assert self.series_list
+
+        self.sample_list = []
+        for series_uid in self.series_list:
+            if fullCt_bool:
+                self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCt(series_uid).hu_a.shape[0])])
+            else:
+                self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCtSampleSize(series_uid))])
+
+        log.info("{!r}: {} {} series, {} slices".format(
+            self,
+            len(self.series_list),
+            {None: 'general', True: 'validation', False: 'training'}[isValSet_bool],
+            len(self.sample_list),
+        ))
+
+    def __len__(self):
+        return len(self.sample_list) #// 100
+
+    def __getitem__(self, ndx):
+        if isinstance(ndx, int):
+            # sample_list holds (series_uid, ct_ndx) pairs; unpack the slice index directly
+            series_uid, ct_ndx = self.sample_list[ndx % len(self.sample_list)]
+            ct = getCt(series_uid)
+            useAugmentation_bool = False
+        else:
+            series_uid, ct_ndx, useAugmentation_bool = ndx
+            ct = getCt(series_uid)
+
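+        # Channels: 2 * contextSlices_count + 1 CT slices centered on ct_ndx, plus
+        # one trailing channel for the 2D lung mask filled in below.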
+        ct_t = torch.zeros((self.contextSlices_count * 2 + 1 + 1, 512, 512))
+
+        start_ndx = ct_ndx - self.contextSlices_count
+        end_ndx = ct_ndx + self.contextSlices_count + 1
+        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
+            context_ndx = max(context_ndx, 0)
+            context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
+
+            ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
+        ct_t /= 1000
+
+        mask_tup = ct.build2dLungMask(ct_ndx)
+
+        ct_t[-1] = torch.from_numpy(mask_tup.lung_mask.astype(np.float32))
+
+        nodule_t = torch.from_numpy(
+            (mask_tup.mal_mask | mask_tup.ben_mask).astype(np.float32)
+        ).unsqueeze(0)
+        ben_t = torch.from_numpy(mask_tup.ben_mask.astype(np.float32)).unsqueeze(0)
+        mal_t = torch.from_numpy(mask_tup.mal_mask.astype(np.float32)).unsqueeze(0)
+        label_int = int(mal_t.max() + ben_t.max() * 2)  # 0 none, 1 malignant, 2 benign, 3 both
+
+        if self.augmentation_dict and useAugmentation_bool:
+            if 'rotate' in self.augmentation_dict:
+                if random.random() > 0.5:
+                    # apply the same transform to every returned mask so they stay aligned
+                    ct_t = ct_t.rot90(1, [1, 2])
+                    nodule_t = nodule_t.rot90(1, [1, 2])
+                    ben_t = ben_t.rot90(1, [1, 2])
+                    mal_t = mal_t.rot90(1, [1, 2])
+
+            if 'flip' in self.augmentation_dict:
+                dims = [d+1 for d in range(2) if random.random() > 0.5]
+
+                if dims:
+                    ct_t = ct_t.flip(dims)
+                    nodule_t = nodule_t.flip(dims)
+                    ben_t = ben_t.flip(dims)
+                    mal_t = mal_t.flip(dims)
+
+            if 'noise' in self.augmentation_dict:
+                noise_t = torch.randn_like(ct_t)
+                noise_t *= self.augmentation_dict['noise']
+
+                ct_t += noise_t
+        return ct_t, nodule_t, label_int, ben_t, mal_t, ct.series_uid, ct_ndx
+
+
+class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
+    def __init__(self, *args, batch_size=80, **kwargs):
+        self.needsShuffle_bool = True
+        self.batch_size = batch_size
+        # self.rotate_frac = 0.5 * len(self.series_list) / len(self)
+        super().__init__(*args, **kwargs)
+
+    def __len__(self):
+        return 50000
+
+    def __getitem__(self, ndx):
+        if self.needsShuffle_bool:
+            random.shuffle(self.series_list)
+            self.needsShuffle_bool = False
+
+        if isinstance(ndx, int):
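+            # Rotate series_list once per batch and index modulo ctCache_depth so
+            # each batch draws from the few CTs held in the getCt lru_cache,
+            # keeping cache hits high while still cycling through every series.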
+            if ndx % self.batch_size == 0:
+                self.series_list.append(self.series_list.pop(0))
+
+            series_uid = self.series_list[ndx % ctCache_depth]
+            ct = getCt(series_uid)
+
+            if ndx % 3 == 0:
+                ct_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
+            elif ndx % 3 == 1:
+                ct_ndx = random.choice(ct.benign_indexes)
+            elif ndx % 3 == 2:
+                ct_ndx = random.choice(list(range(ct.hu_a.shape[0])))
+
+            useAugmentation_bool = True
+        else:
+            series_uid, ct_ndx, useAugmentation_bool = ndx
+
+        return super().__getitem__((series_uid, ct_ndx, useAugmentation_bool))

+ 68 - 0
p2ch13/model.py

@@ -0,0 +1,68 @@
+
+import torch
+from torch import nn as nn
+
+from util.logconf import logging
+from util.unet import UNet
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+class LunaModel(nn.Module):
+    def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
+        super().__init__()
+
+        layer_list = []
+        for layer_ndx in range(layer_count):
+            layer_list += [
+                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False),
+                nn.BatchNorm3d(conv_channels), # eli: will assume that p1ch6 doesn't use this
+                nn.LeakyReLU(inplace=True), # eli: will assume plain ReLU
+                nn.Dropout3d(p=0.2),  # eli: will assume that p1ch6 doesn't use this
+
+                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=False),
+                nn.BatchNorm3d(conv_channels),
+                nn.LeakyReLU(inplace=True),
+                nn.Dropout3d(p=0.2),
+
+                nn.MaxPool3d(2, 2),
+            ]
+
+            in_channels = conv_channels
+            conv_channels *= 2
+
+        self.convAndPool_seq = nn.Sequential(*layer_list)
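+        # 512 assumes the default layer_count / conv_channels and input crop size;
+        # the try/except in forward() logs the actual flattened size if it differs.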
+        self.fullyConnected_layer = nn.Linear(512, 1)
+        self.final = nn.Hardtanh(min_val=0.0, max_val=1.0)
+
+
+    def forward(self, input_batch):
+        conv_output = self.convAndPool_seq(input_batch)
+        conv_flat = conv_output.view(conv_output.size(0), -1)
+
+        try:
+            classifier_output = self.fullyConnected_layer(conv_flat)
+        except Exception:
+            log.debug(conv_flat.size())
+            raise
+
+        classifier_output = self.final(classifier_output)
+        return classifier_output
+
+class UNetWrapper(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
+        self.unet = UNet(**kwargs)
+        self.hardtanh = nn.Hardtanh(min_val=0, max_val=1)
+
+    def forward(self, input):
+        bn_output = self.batchnorm(input)
+        un_output = self.unet(bn_output)
+        ht_output = self.hardtanh(un_output)
+
+        return ht_output

+ 92 - 0
p2ch13/model_cls.py

@@ -0,0 +1,92 @@
+import math
+
+import torch.nn as nn
+
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+
+class LunaModel(nn.Module):
+    def __init__(self, in_channels=1, conv_channels=8):
+        super().__init__()
+
+        self.tail_batchnorm = nn.BatchNorm3d(1)
+
+        self.block1 = LunaBlock(in_channels, conv_channels)
+        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
+        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
+        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)
+
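+        # 1152 = 64 * 2 * 3 * 3: four blocks of 2x max-pooling reduce the
+        # (32, 48, 48) input crop to (2, 3, 3) at conv_channels * 8 = 64 channels.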
+        self.head_linear = nn.Linear(1152, 2)
+        self.head_softmax = nn.Softmax(dim=1)
+
+        self._init_weights()
+
+    # see also https://github.com/pytorch/pytorch/issues/18182
+    def _init_weights(self):
+        for m in self.modules():
+            if type(m) in {
+                nn.Linear,
+                nn.Conv3d,
+                nn.Conv2d,
+                nn.ConvTranspose2d,
+                nn.ConvTranspose3d,
+            }:
+                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
+
+
+    def forward(self, input_batch):
+        bn_output = self.tail_batchnorm(input_batch)
+
+        block_out = self.block1(bn_output)
+        block_out = self.block2(block_out)
+        block_out = self.block3(block_out)
+        block_out = self.block4(block_out)
+
+        conv_flat = block_out.view(
+            block_out.size(0),
+            -1,
+        )
+        linear_output = self.head_linear(conv_flat)
+
+        return linear_output, self.head_softmax(linear_output)
+
+
+class LunaBlock(nn.Module):
+    def __init__(self, in_channels, conv_channels):
+        super().__init__()
+
+        self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.maxpool = nn.MaxPool3d(2, 2)
+
+    def forward(self, input_batch):
+        block_out = self.conv1(input_batch)
+        block_out = self.relu1(block_out)
+        block_out = self.conv2(block_out)
+        block_out = self.relu2(block_out)
+
+        return self.maxpool(block_out)
+
+
+class AlternateLunaModel(LunaModel):
+    def __init__(self, in_channels=1, conv_channels=64):
+        super().__init__()
+
+        self.block1 = LunaBlock(in_channels, conv_channels)
+        self.block2 = LunaBlock(conv_channels, conv_channels // 2)
+        self.block3 = LunaBlock(conv_channels // 2, conv_channels // 4)
+        self.block4 = LunaBlock(conv_channels // 4, conv_channels // 8)
+
+        self.head_linear = nn.Linear(144, 2)

+ 46 - 0
p2ch13/model_seg.py

@@ -0,0 +1,46 @@
+import math
+
+from torch import nn as nn
+
+from util.logconf import logging
+from util.unet import UNet
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+class UNetWrapper(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__()
+
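+        # BatchNorm2d on the raw input normalizes the stacked CT slices; the final
+        # Sigmoid maps UNet logits to per-pixel nodule probabilities.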
+        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
+        self.unet = UNet(**kwargs)
+        self.final = nn.Sigmoid()
+
+        self._init_weights()
+
+    def _init_weights(self):
+        init_set = {
+            nn.Conv2d,
+            nn.Conv3d,
+            nn.ConvTranspose2d,
+            nn.ConvTranspose3d,
+            nn.Linear,
+        }
+        for m in self.modules():
+            if type(m) in init_set:
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu', a=0)
+                if m.bias is not None:
+                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
+
+
+    def forward(self, input_batch):
+        bn_output = self.input_batchnorm(input_batch)
+        un_output = self.unet(bn_output)
+        fn_output = self.final(un_output)
+
+        return fn_output
+

+ 328 - 0
p2ch13/model_segmentation.py

@@ -0,0 +1,328 @@
+import torch
+from torch import nn as nn
+
+from util.logconf import logging
+from util.unet import UNet
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+# torch.backends.cudnn.enabled = False
+
+class UNetWrapper(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
+        self.unet = UNet(**kwargs)
+        self.hardtanh = nn.Hardtanh(min_val=0, max_val=1)
+
+    def forward(self, input):
+        bn_output = self.batchnorm(input)
+        un_output = self.unet(bn_output)
+        ht_output = self.hardtanh(un_output)
+
+        return ht_output
+
+
+
+class Simple2dSegmentationModel(nn.Module):
+    def __init__(self, layers, in_channels, conv_channels, final_channels):
+        super().__init__()
+        self.layers = layers
+
+        self.in_channels = in_channels
+        self.conv_channels = conv_channels
+        self.final_channels = final_channels
+
+        layer_list = [
+            nn.Conv2d(self.in_channels, self.conv_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(self.conv_channels),
+            # nn.GroupNorm(1, self.conv_channels),
+            # nn.ReLU(inplace=True),
+            nn.LeakyReLU(inplace=True),
+        ]
+
+        for i in range(self.layers):
+            layer_list.extend([
+                nn.Conv2d(self.conv_channels, self.conv_channels, kernel_size=3, padding=1),
+                nn.BatchNorm2d(self.conv_channels),
+                # nn.GroupNorm(1, self.conv_channels),
+                # nn.ReLU(inplace=True),
+                nn.LeakyReLU(inplace=True),
+            ])
+
+        layer_list.extend([
+            nn.Conv2d(self.conv_channels, self.final_channels, kernel_size=1, bias=True),
+            nn.Hardtanh(min_val=0, max_val=1),
+        ])
+
+        self.layer_seq = nn.Sequential(*layer_list)
+
+
+    def forward(self, in_data):
+        return self.layer_seq(in_data)
+
+
+class Dense2dSegmentationModel(nn.Module):
+    def __init__(self, layers, input_channels, conv_channels, bottleneck_channels, final_channels):
+        super().__init__()
+        self.layers = layers
+
+        self.input_channels = input_channels
+        self.conv_channels = conv_channels
+        self.bottleneck_channels = bottleneck_channels
+        self.final_channels = final_channels
+
+        self.layer_list = nn.ModuleList()
+
+        for i in range(layers):
+            self.layer_list.append(
+                Dense2dSegmentationBlock(
+                    input_channels + bottleneck_channels * i,
+                    conv_channels,
+                    bottleneck_channels,
+                )
+            )
+
+        self.layer_list.append(
+            Dense2dSegmentationBlock(
+                input_channels + bottleneck_channels * layers,
+                conv_channels,
+                bottleneck_channels,
+                final_channels,
+            )
+        )
+
+        self.htanh_layer = nn.Hardtanh(min_val=0, max_val=1)
+
+    def forward(self, input_tensor):
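+        # DenseNet-style connectivity: each block sees the channel-wise concat of
+        # the input and every earlier block's output; only the final block's
+        # output, squashed by Hardtanh, is returned.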
+        concat_list = [input_tensor]
+        for layer_block in self.layer_list:
+            layer_output = layer_block(torch.cat(concat_list, dim=1))
+            concat_list.append(layer_output)
+
+        return self.htanh_layer(concat_list[-1])
+
+
+class Dense2dSegmentationBlock(nn.Module):
+    def __init__(self, input_channels, conv_channels, bottleneck_channels, final_channels=None):
+        super().__init__()
+
+        self.input_channels = input_channels
+        self.conv_channels = conv_channels
+        self.bottleneck_channels = bottleneck_channels
+        self.final_channels = final_channels or bottleneck_channels
+
+        self.conv1_seq = nn.Sequential(
+            nn.Conv2d(self.input_channels, self.bottleneck_channels, kernel_size=1),
+            nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
+            nn.Conv2d(self.conv_channels, self.bottleneck_channels, kernel_size=1),
+            # nn.BatchNorm2d(self.conv_channels),
+            nn.GroupNorm(1, self.bottleneck_channels),
+            # nn.ReLU(inplace=True),
+            nn.LeakyReLU(inplace=True),
+        )
+
+        self.conv2_seq = nn.Sequential(
+            nn.Conv2d(self.input_channels + self.bottleneck_channels, self.bottleneck_channels, kernel_size=1),
+            nn.Conv2d(self.bottleneck_channels, self.conv_channels, kernel_size=3, padding=1),
+            nn.Conv2d(self.conv_channels, self.final_channels, kernel_size=1),
+            # nn.BatchNorm2d(self.conv_channels),
+            nn.GroupNorm(1, self.final_channels),
+            # nn.ReLU(inplace=True),
+            nn.LeakyReLU(inplace=True),
+        )
+
+    def forward(self, input_tensor):
+        conv1_tensor = self.conv1_seq(input_tensor)
+        conv2_tensor = self.conv2_seq(torch.cat([input_tensor, conv1_tensor], dim=1))
+
+        return conv2_tensor
+
+
+class SegmentationModel(nn.Module):
+    def __init__(self, depth, in_channels, tail_channels=None, out_channels=None, final_channels=None):
+        super().__init__()
+        self.depth = depth
+
+        # self.in_size = in_size
+        # self.tailOut_size = in_size #self.in_size - 4
+        # self.headIn_size = in_size #None
+        # self.out_size = in_size #None
+
+        self.in_channels = in_channels
+        self.tailOut_channels = tail_channels or in_channels * 2
+        self.headIn_channels = None
+        self.out_channels = out_channels or self.tailOut_channels
+        self.final_channels = final_channels
+
+        # assert in_size % 2 == 0, repr([in_size, depth])
+
+        self.tail_seq = nn.Sequential(
+            nn.ReplicationPad3d(2),
+            nn.Conv3d(self.in_channels, self.tailOut_channels, 3),
+            nn.GroupNorm(1, self.tailOut_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(self.tailOut_channels, self.tailOut_channels, 3),
+            nn.GroupNorm(1, self.tailOut_channels),
+            nn.ReLU(inplace=True),
+        )
+
+        if depth:
+            self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
+            self.child_layer = SegmentationModel(depth - 1, self.tailOut_channels)
+
+            self.headIn_channels = self.in_channels + self.tailOut_channels + self.child_layer.out_channels
+            # self.headIn_size = self.child_layer.out_size * 2
+            # self.out_size = self.headIn_size #- 4
+
+            # self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear')
+        else:
+            self.downsample_layer = None
+            self.child_layer = None
+            # self.upsample_layer = None
+
+            self.headIn_channels = self.in_channels + self.tailOut_channels
+            # self.headIn_size = self.tailOut_size
+            # self.out_size = self.headIn_size #- 4
+
+        self.head_seq = nn.Sequential(
+            nn.ReplicationPad3d(2),
+            nn.Conv3d(self.headIn_channels, self.out_channels, 3),
+            nn.GroupNorm(1, self.out_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(self.out_channels, self.out_channels, 3),
+            nn.GroupNorm(1, self.out_channels),
+            nn.ReLU(inplace=True),
+        )
+
+        if self.final_channels:
+            self.final_seq = nn.Sequential(
+                nn.ReplicationPad3d(1),
+                nn.Conv3d(self.out_channels, self.final_channels, 1),
+            )
+        else:
+            self.final_seq = None
+
+    def forward(self, in_data):
+
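+        # Recursive U-Net-style step: tail convs at this depth, recurse on a 2x
+        # downsample, upsample the child output, then concat input + tail + child
+        # before the head convs.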
+        assert in_data.is_contiguous()
+
+        try:
+            tail_out = self.tail_seq(in_data)
+        except Exception:
+            log.debug([in_data.size()])
+            raise
+
+        if self.downsample_layer:
+            down_out = self.downsample_layer(tail_out)
+            child_out = self.child_layer(down_out)
+            # up_out = self.upsample_layer(child_out)
+
+            up_out = nn.functional.interpolate(child_out, scale_factor=2, mode='trilinear')
+
+            # crop_int = (tail_out.size(-1) - up_out.size(-1)) // 2
+            # crop_out = tail_out[:, :, crop_int:-crop_int, crop_int:-crop_int, crop_int:-crop_int]
+            # combined_out = torch.cat([crop_out, up_out], 1)
+
+            combined_out = torch.cat([in_data, tail_out, up_out], 1)
+        else:
+            combined_out = torch.cat([in_data, tail_out], 1)
+
+        head_out = self.head_seq(combined_out)
+
+        if self.final_seq:
+            final_out = self.final_seq(head_out)
+            return final_out
+        else:
+            return head_out
+
+
+class DenseSegmentationModel(nn.Module):
+    def __init__(self, depth, in_channels, conv_channels, final_channels=None):
+        super().__init__()
+        self.depth = depth
+
+        self.in_channels = in_channels
+        self.conv_channels = conv_channels
+        self.final_channels = final_channels
+
+        self.convA_seq = nn.Sequential(
+            nn.Conv3d(self.in_channels, self.conv_channels // 4, 1),
+            nn.ReplicationPad3d(1),
+            nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
+            nn.BatchNorm3d(self.conv_channels),
+            nn.ReLU(inplace=True),
+        )
+
+        self.convB_seq = nn.Sequential(
+            nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
+            nn.ReplicationPad3d(1),
+            nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
+            nn.BatchNorm3d(self.conv_channels),
+            nn.ReLU(inplace=True),
+        )
+
+        if self.depth:
+            self.downsample_layer = nn.MaxPool3d(kernel_size=2, stride=2)
+            self.child_layer = SegmentationModel(depth - 1, self.conv_channels, self.conv_channels * 2)
+            self.upsample_layer = nn.Upsample(scale_factor=2, mode='trilinear')
+
+            self.convC_seq = nn.Sequential(
+                nn.Conv3d(self.in_channels + self.conv_channels * 3, self.conv_channels // 4, 1),
+                nn.ReplicationPad3d(1),
+                nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
+                nn.BatchNorm3d(self.conv_channels),
+                nn.ReLU(inplace=True),
+            )
+        else:
+            self.downsample_layer = None
+            self.child_layer = None
+            self.upsample_layer = None
+
+            self.convC_seq = nn.Sequential(
+                nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
+                nn.ReplicationPad3d(1),
+                nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
+                nn.BatchNorm3d(self.conv_channels),
+                nn.ReLU(inplace=True),
+            )
+
+        self.convD_seq = nn.Sequential(
+            nn.Conv3d(self.in_channels + self.conv_channels, self.conv_channels // 4, 1),
+            nn.ReplicationPad3d(1),
+            nn.Conv3d(self.conv_channels // 4, self.conv_channels, 3),
+            nn.BatchNorm3d(self.conv_channels),
+            nn.ReLU(inplace=True),
+        )
+
+        if self.final_channels:
+            self.final_seq = nn.Sequential(
+                # nn.ReplicationPad3d(1),
+                nn.Conv3d(self.conv_channels, self.final_channels, 1),
+            )
+        else:
+            self.final_seq = None
+
+    def forward(self, data_in):
+        a_out = self.convA_seq(data_in)
+        b_out = self.convB_seq(torch.cat([data_in, a_out], 1))
+
+        if self.downsample_layer:
+            down_out = self.downsample_layer(b_out)
+            child_out = self.child_layer(down_out)
+            up_out = self.upsample_layer(child_out)
+
+            c_out = self.convC_seq(torch.cat([data_in, b_out, up_out], 1))
+        else:
+            c_out = self.convC_seq(torch.cat([data_in, b_out], 1))
+
+        d_out = self.convD_seq(torch.cat([data_in, c_out], 1))
+
+        if self.final_seq:
+            return self.final_seq(d_out)
+        else:
+            return d_out

+ 68 - 0
p2ch13/prepcache.py

@@ -0,0 +1,68 @@
+import argparse
+import sys
+
+import numpy as np
+
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import PrepcacheLunaDataset, getCtSampleSize
+from util.logconf import logging
+# from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+
+class LunaPrepCacheApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=1024,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        # parser.add_argument('--scaled',
+        #     help="Scale the CT chunks to square voxels.",
+        #     default=False,
+        #     action='store_true',
+        # )
+
+        self.cli_args = parser.parse_args(sys_argv)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.prep_dl = DataLoader(
+            PrepcacheLunaDataset(
+                sortby_str='series_uid',
+            ),
+            batch_size=self.cli_args.batch_size,
+            num_workers=self.cli_args.num_workers,
+        )
+
+        batch_iter = enumerateWithEstimate(
+            self.prep_dl,
+            "Stuffing cache",
+            start_ndx=self.prep_dl.num_workers,
+        )
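+        # Iterating the loader is the whole point: each __getitem__ populates the
+        # disk-backed caches (getCtRawNodule / getCtSampleSize) as a side effect.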
+        for batch_ndx, batch_tup in batch_iter:
+            pass
+
+
+if __name__ == '__main__':
+    LunaPrepCacheApp().main()

+ 92 - 0
p2ch13/screencts.py

@@ -0,0 +1,92 @@
+import argparse
+import sys
+
+import numpy as np
+
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import Dataset, DataLoader
+
+from util.util import enumerateWithEstimate, prhist
+from .dsets import getNoduleInfoList, getCt
+from util.logconf import logging
+# from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+
+class LunaScreenCtDataset(Dataset):
+    def __init__(self):
+        self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
+    def __len__(self):
+        return len(self.series_list)
+
+    def __getitem__(self, ndx):
+        series_uid = self.series_list[ndx]
+        ct = getCt(series_uid)
+        mid_ndx = ct.hu_a.shape[0] // 2
+
+        # build2dLungMask returns a 9-field MaskTuple, not the seven masks the
+        # original unpack expected; the dense-tissue-to-body-area ratio is used
+        # here as the screening metric (assumed intent).
+        mask_tup = ct.build2dLungMask(mid_ndx)
+
+        return series_uid, float(mask_tup.dense_mask.sum() / mask_tup.body_mask.sum())
+
+
+class LunaScreenCtApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=4,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        # parser.add_argument('--scaled',
+        #     help="Scale the CT chunks to square voxels.",
+        #     default=False,
+        #     action='store_true',
+        # )
+
+        self.cli_args = parser.parse_args(sys_argv)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.prep_dl = DataLoader(
+            LunaScreenCtDataset(),
+            batch_size=self.cli_args.batch_size,
+            num_workers=self.cli_args.num_workers,
+        )
+
+        series2ratio_dict = {}
+
+        batch_iter = enumerateWithEstimate(
+            self.prep_dl,
+            "Screening CTs",
+            start_ndx=self.prep_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            series_list, ratio_list = batch_tup
+            for series_uid, ratio_float in zip(series_list, ratio_list):
+                series2ratio_dict[series_uid] = ratio_float
+            # break
+
+        prhist(list(series2ratio_dict.values()))
+
+
+
+
+if __name__ == '__main__':
+    sys.exit(LunaScreenCtApp().main() or 0)

+ 459 - 0
p2ch13/train_cls.py

@@ -0,0 +1,459 @@
+import argparse
+import datetime
+import os
+import sys
+
+import numpy as np
+
+from torch.utils.tensorboard import SummaryWriter
+
+import torch
+import torch.nn as nn
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset
+from .model_cls import LunaModel
+
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
+METRICS_LABEL_NDX=0
+METRICS_PRED_NDX=1
+METRICS_LOSS_NDX=2
+METRICS_SIZE = 3
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=32,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+        parser.add_argument('--balanced',
+            help="Balance the training data to half benign, half malignant.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augmented',
+            help="Augment the training data.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-flip',
+            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-offset',
+            help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-scale',
+            help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-rotate',
+            help="Augment the training data by randomly rotating the data around the head-foot axis.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-noise',
+            help="Augment the training data by randomly adding noise to the data.",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch13',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='dlwpt',
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
+
+        self.trn_writer = None
+        self.val_writer = None
+        self.totalTrainingSamples_count = 0
+
+        self.augmentation_dict = {}
+        if self.cli_args.augmented or self.cli_args.augment_flip:
+            self.augmentation_dict['flip'] = True
+        if self.cli_args.augmented or self.cli_args.augment_offset:
+            self.augmentation_dict['offset'] = 0.1
+        if self.cli_args.augmented or self.cli_args.augment_scale:
+            self.augmentation_dict['scale'] = 0.2
+        if self.cli_args.augmented or self.cli_args.augment_rotate:
+            self.augmentation_dict['rotate'] = True
+        if self.cli_args.augmented or self.cli_args.augment_noise:
+            self.augmentation_dict['noise'] = 25.0
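+        # offset and scale are relative factors applied in the normalized affine
+        # space of affine_grid; noise is added directly in (clamped) Hounsfield units.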
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        self.model = self.initModel()
+        self.optimizer = self.initOptimizer()
+
+
+    def initModel(self):
+        model = LunaModel()
+        if self.use_cuda:
+            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
+            if torch.cuda.device_count() > 1:
+                model = nn.DataParallel(model)
+            model = model.to(self.device)
+        return model
+
+    def initOptimizer(self):
+        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
+        # return Adam(self.model.parameters())
+
+    def initTrainDl(self):
+        train_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=False,
+            ratio_int=int(self.cli_args.balanced),
+            augmentation_dict=self.augmentation_dict,
+        )
+
+        train_dl = DataLoader(
+            train_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return train_dl
+
+    def initValDl(self):
+        val_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=True,
+        )
+
+        val_dl = DataLoader(
+            val_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return val_dl
+
+    def initTensorboardWriters(self):
+        if self.trn_writer is None:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
+            self.val_writer = SummaryWriter(log_dir=log_dir + '-val_cls-' + self.cli_args.comment)
+
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+        val_dl = self.initValDl()
+
+        best_score = 0.0
+
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(val_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
+
+            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
+            score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
+            best_score = max(score, best_score)
+
+            self.saveModel('cls', epoch_ndx, score == best_score)
+
+        if self.trn_writer is not None:
+            self.trn_writer.close()
+            self.val_writer.close()
+
+
+    def doTraining(self, epoch_ndx, train_dl):
+        self.model.train()
+        train_dl.dataset.shuffleSamples()
+        trnMetrics_g = torch.zeros(
+                METRICS_SIZE,
+                len(train_dl.dataset),
+            ).to(self.device)
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            loss_var = self.computeBatchLoss(
+                batch_ndx,
+                batch_tup,
+                train_dl.batch_size,
+                trnMetrics_g
+            )
+
+            loss_var.backward()
+            self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += len(train_dl.dataset)
+
+        return trnMetrics_g.to('cpu')
+
+
+    def doValidation(self, epoch_ndx, val_dl):
+        with torch.no_grad():
+            self.model.eval()
+            valMetrics_g = torch.zeros(
+                    METRICS_SIZE,
+                    len(val_dl.dataset),
+                ).to(self.device)
+            batch_iter = enumerateWithEstimate(
+                val_dl,
+                "E{} Validation ".format(epoch_ndx),
+                start_ndx=val_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                self.computeBatchLoss(
+                    batch_ndx,
+                    batch_tup,
+                    val_dl.batch_size,
+                    valMetrics_g,
+                )
+
+        return valMetrics_g.to('cpu')
+
+
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
+        input_t, label_t, _series_list, _center_list = batch_tup
+
+        input_g = input_t.to(self.device, non_blocking=True)
+        label_g = label_t.to(self.device, non_blocking=True)
+
+        logits_g, probability_g = self.model(input_g)
+
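+        # reduction='none' keeps one loss value per sample so it can be recorded
+        # in metrics_g below; only the mean is used for backpropagation.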
+        loss_func = nn.CrossEntropyLoss(reduction='none')
+        loss_g = loss_func(
+            logits_g,
+            label_g[:,1],
+        )
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_t.size(0)
+
+        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_g[:,1]
+        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_g[:,1]
+        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g
+
+        return loss_g.mean()
+
+
+    def logMetrics(
+            self,
+            epoch_ndx,
+            mode_str,
+            metrics_t,
+    ):
+        self.initTensorboardWriters()
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+        metrics_t = metrics_t.detach().numpy()
+
+        benLabel_mask = metrics_t[METRICS_LABEL_NDX] <= 0.5
+        benPred_mask = metrics_t[METRICS_PRED_NDX] <= 0.5
+
+        malLabel_mask = ~benLabel_mask
+        malPred_mask = ~benPred_mask
+
+        ben_count = benLabel_mask.sum()
+        mal_count = malLabel_mask.sum()
+
+        trueNeg_count = ben_correct = (benLabel_mask & benPred_mask).sum()
+        truePos_count = mal_correct = (malLabel_mask & malPred_mask).sum()
+
+        falsePos_count = ben_count - ben_correct
+        falseNeg_count = mal_count - mal_correct
+
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_t[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/ben'] = metrics_t[METRICS_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = metrics_t[METRICS_LOSS_NDX, malLabel_mask].mean()
+
+        metrics_dict['correct/all'] = (mal_correct + ben_correct) / metrics_t.shape[1] * 100
+        metrics_dict['correct/ben'] = (ben_correct) / ben_count * 100
+        metrics_dict['correct/mal'] = (mal_correct) / mal_count * 100
+
+        precision = metrics_dict['pr/precision'] = \
+            truePos_count / (truePos_count + falsePos_count)
+        recall = metrics_dict['pr/recall'] = \
+            truePos_count / (truePos_count + falseNeg_count)
+
+        metrics_dict['pr/f1_score'] = \
+            2 * (precision * recall) / (precision + recall)
+
+        log.info(
+            ("E{} {:8} {loss/all:.4f} loss, "
+                 + "{correct/all:-5.1f}% correct, "
+                 + "{pr/precision:.4f} precision, "
+                 + "{pr/recall:.4f} recall, "
+                 + "{pr/f1_score:.4f} f1 score"
+            ).format(
+                epoch_ndx,
+                mode_str,
+                **metrics_dict,
+            )
+        )
+        log.info(
+            ("E{} {:8} {loss/ben:.4f} loss, "
+                 + "{correct/ben:-5.1f}% correct ({ben_correct:} of {ben_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_ben',
+                ben_correct=ben_correct,
+                ben_count=ben_count,
+                **metrics_dict,
+            )
+        )
+        log.info(
+            ("E{} {:8} {loss/mal:.4f} loss, "
+                 + "{correct/mal:-5.1f}% correct ({mal_correct:} of {mal_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_mal',
+                mal_correct=mal_correct,
+                mal_count=mal_count,
+                **metrics_dict,
+            )
+        )
+        writer = getattr(self, mode_str + '_writer')
+
+        for key, value in metrics_dict.items():
+            writer.add_scalar(key, value, self.totalTrainingSamples_count)
+
+        writer.add_pr_curve(
+            'pr',
+            metrics_t[METRICS_LABEL_NDX],
+            metrics_t[METRICS_PRED_NDX],
+            self.totalTrainingSamples_count,
+        )
+
+        bins = [x/50.0 for x in range(51)]
+
+        benHist_mask = benLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
+        malHist_mask = malLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)
+
+        if benHist_mask.any():
+            writer.add_histogram(
+                'is_ben',
+                metrics_t[METRICS_PRED_NDX, benHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+        if malHist_mask.any():
+            writer.add_histogram(
+                'is_mal',
+                metrics_t[METRICS_PRED_NDX, malHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+
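+        # Heuristic checkpoint-selection score: reward F1, lightly penalize
+        # malignant loss, and use overall loss as a tiebreaker.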
+        score = 1 \
+            + metrics_dict['pr/f1_score'] \
+            - metrics_dict['loss/mal'] * 0.01 \
+            - metrics_dict['loss/all'] * 0.0001
+
+        return score
+
+    def saveModel(self, type_str, epoch_ndx, isBest=False):
+        file_path = os.path.join(
+            'data-unversioned',
+            'part2',
+            'models',
+            self.cli_args.tb_prefix,
+            '{}_{}_{}.{}.state'.format(
+                type_str,
+                self.time_str,
+                self.cli_args.comment,
+                self.totalTrainingSamples_count,
+            )
+        )
+
+        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
+
+        model = self.model
+        if hasattr(model, 'module'):
+            model = model.module
+
+        state = {
+            'model_state': model.state_dict(),
+            'model_name': type(model).__name__,
+            'optimizer_state' : self.optimizer.state_dict(),
+            'optimizer_name': type(self.optimizer).__name__,
+            'epoch': epoch_ndx,
+            'totalTrainingSamples_count': self.totalTrainingSamples_count,
+            # 'resumed_from': self.cli_args.resume,
+        }
+        torch.save(state, file_path)
+
+        log.debug("Saved model params to {}".format(file_path))
+
+        if isBest:
+            file_path = os.path.join(
+                'data-unversioned',
+                'part2',
+                'models',
+                self.cli_args.tb_prefix,
+                '{}_{}_{}.{}.state'.format(
+                    type_str,
+                    self.time_str,
+                    self.cli_args.comment,
+                    'best',
+                )
+            )
+            torch.save(state, file_path)
+
+            log.debug("Saved model params to {}".format(file_path))
+
+if __name__ == '__main__':
+    LunaTrainingApp().main()
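
The checkpoint written by saveModel above is a plain dict handed to torch.save. As a sanity check of that format, a minimal loading sketch; the path is hypothetical (real filenames embed a timestamp and sample count), and LunaModel is assumed to live in p2ch13.model_cls, matching the import in p2ch14/diagnose.py below:

    import torch
    from p2ch13.model_cls import LunaModel  # assumed module, per diagnose.py

    # Hypothetical checkpoint path; '*.best.state' is the name saveModel
    # writes for the best-scoring epoch.
    state = torch.load(
        'data-unversioned/part2/models/p2ch13/cls_2019-11-01_10.00.00_none.best.state',
        map_location='cpu',
    )

    model = LunaModel()
    model.load_state_dict(state['model_state'])  # weights from the unwrapped model
    model.eval()
    print(state['model_name'], state['epoch'], state['totalTrainingSamples_count'])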

+ 580 - 0
p2ch13/train_seg.py

@@ -0,0 +1,580 @@
+import argparse
+import datetime
+import os
+import socket
+import sys
+
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.optim import SGD, Adam
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
+from util.logconf import logging
+from util.util import xyz2irc
+from .model_seg import UNetWrapper
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
+METRICS_LABEL_NDX = 0
+METRICS_LOSS_NDX = 1
+METRICS_MAL_LOSS_NDX = 2
+# METRICS_ALL_LOSS_NDX = 3
+
+METRICS_MTP_NDX = 4
+METRICS_MFN_NDX = 5
+METRICS_MFP_NDX = 6
+METRICS_ATP_NDX = 7
+METRICS_AFN_NDX = 8
+METRICS_AFP_NDX = 9
+
+METRICS_SIZE = 10
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=16,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+
+        parser.add_argument('--augmented',
+            help="Augment the training data.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-flip',
+            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
+            action='store_true',
+            default=False,
+        )
+        # parser.add_argument('--augment-offset',
+        #     help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
+        #     action='store_true',
+        #     default=False,
+        # )
+        # parser.add_argument('--augment-scale',
+        #     help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
+        #     action='store_true',
+        #     default=False,
+        # )
+        parser.add_argument('--augment-rotate',
+            help="Augment the training data by randomly rotating the data around the head-foot axis.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-noise',
+            help="Augment the training data by randomly adding noise to the data.",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch13',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='none',
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
+        self.totalTrainingSamples_count = 0
+        self.trn_writer = None
+        self.val_writer = None
+
+        augmentation_dict = {}
+        if self.cli_args.augmented or self.cli_args.augment_flip:
+            augmentation_dict['flip'] = True
+        if self.cli_args.augmented or self.cli_args.augment_rotate:
+            augmentation_dict['rotate'] = True
+        if self.cli_args.augmented or self.cli_args.augment_noise:
+            augmentation_dict['noise'] = 0.025
+        self.augmentation_dict = augmentation_dict
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        self.model = self.initModel()
+        self.optimizer = self.initOptimizer()
+
+
+    def initModel(self):
+        model = UNetWrapper(
+            in_channels=8,
+            n_classes=1,
+            depth=4,
+            wf=3,
+            padding=True,
+            batch_norm=True,
+            up_mode='upconv',
+        )
+
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                model = nn.DataParallel(model)
+            model = model.to(self.device)
+        return model
+
+    def initOptimizer(self):
+        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
+        # return Adam(self.model.parameters())
+
+
+    def initTrainDl(self):
+        train_ds = TrainingLuna2dSegmentationDataset(
+            val_stride=10,
+            isValSet_bool=False,
+            contextSlices_count=3,
+            augmentation_dict=self.augmentation_dict,
+        )
+
+        train_dl = DataLoader(
+            train_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return train_dl
+
+    def initValDl(self):
+        val_ds = Luna2dSegmentationDataset(
+            val_stride=10,
+            isValSet_bool=True,
+            contextSlices_count=3,
+        )
+
+        val_dl = DataLoader(
+            val_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return val_dl
+
+    def initTensorboardWriters(self):
+        if self.trn_writer is None:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
+            self.val_writer = SummaryWriter(log_dir=log_dir + '_val_seg_' + self.cli_args.comment)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+        val_dl = self.initValDl()
+
+        # self.logModelMetrics(self.model)
+
+        best_score = 0.0
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(val_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
+            self.logImages(epoch_ndx, 'trn', train_dl)
+            self.logImages(epoch_ndx, 'val', val_dl)
+            # self.logModelMetrics(self.model)
+
+            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
+            score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
+            best_score = max(score, best_score)
+
+            self.saveModel('seg', epoch_ndx, score == best_score)
+
+        if hasattr(self, 'trn_writer'):
+            self.trn_writer.close()
+            self.val_writer.close()
+
+    def doTraining(self, epoch_ndx, train_dl):
+        trnMetrics_g = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        self.model.train()
+
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trnMetrics_g)
+            loss_var.backward()
+
+            self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trnMetrics_g.size(1)
+
+        return trnMetrics_g.to('cpu')
+
+    def doValidation(self, epoch_ndx, val_dl):
+        with torch.no_grad():
+            valMetrics_g = torch.zeros(METRICS_SIZE, len(val_dl.dataset)).to(self.device)
+            self.model.eval()
+
+            batch_iter = enumerateWithEstimate(
+                val_dl,
+                "E{} Validation ".format(epoch_ndx),
+                start_ndx=val_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                self.computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)
+
+        return valMetrics_g.to('cpu')
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
+        input_t, label_t, label_list, ben_t, mal_t, _, _ = batch_tup
+
+        input_g = input_t.to(self.device, non_blocking=True)
+        label_g = label_t.to(self.device, non_blocking=True)
+        mal_g = mal_t.to(self.device, non_blocking=True)
+        ben_g = ben_t.to(self.device, non_blocking=True)
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_t.size(0)
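+        # Per-sample overlap count: flatten all but the batch dim, then sum.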
+        intersectionSum = lambda a, b: (a * b).view(a.size(0), -1).sum(dim=1)
+
+        prediction_g = self.model(input_g)
+        diceLoss_g = self.diceLoss(label_g, prediction_g)
+
+        with torch.no_grad():
+            malLoss_g = self.diceLoss(mal_g, prediction_g * mal_g, p=True)
+            predictionBool_g = (prediction_g > 0.5).to(torch.float32)
+
+            metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
+            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
+            metrics_g[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_g
+
+            malPred_g = predictionBool_g * mal_g
+            tp = intersectionSum(    mal_g,       malPred_g)
+            fn = intersectionSum(    mal_g,   1 - malPred_g)
+
+            metrics_g[METRICS_MTP_NDX, start_ndx:end_ndx] = tp
+            metrics_g[METRICS_MFN_NDX, start_ndx:end_ndx] = fn
+
+            del malPred_g, tp, fn
+
+            tp = intersectionSum(    label_g,     predictionBool_g)
+            fn = intersectionSum(    label_g, 1 - predictionBool_g)
+            fp = intersectionSum(1 - label_g,     predictionBool_g)
+
+            metrics_g[METRICS_ATP_NDX, start_ndx:end_ndx] = tp
+            metrics_g[METRICS_AFN_NDX, start_ndx:end_ndx] = fn
+            metrics_g[METRICS_AFP_NDX, start_ndx:end_ndx] = fp
+
+            del tp, fn, fp
+
+        return diceLoss_g.mean()
+
+    # def diceLoss(self, label_g, prediction_g, epsilon=0.01, p=False):
+    def diceLoss(self, label_g, prediction_g, epsilon=1, p=False):
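+        # Soft Dice loss: 1 - 2|pred * label| / (|pred| + |label|), computed
+        # per sample; epsilon keeps the ratio defined when both sums are zero.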
+        sum_dim1 = lambda t: t.view(t.size(0), -1).sum(dim=1)
+
+        diceLabel_g = sum_dim1(label_g)
+        dicePrediction_g = sum_dim1(prediction_g)
+        diceCorrect_g = sum_dim1(prediction_g * label_g)
+
+        epsilon_g = torch.ones_like(diceCorrect_g) * epsilon
+        diceLoss_g = 1 - (2 * diceCorrect_g + epsilon_g) \
+            / (dicePrediction_g + diceLabel_g + epsilon_g)
+
+        if p and diceLoss_g.mean() < 0:
+            correct_tmp = prediction_g * label_g
+
+            log.debug([])
+            log.debug(['diceCorrect_g   ', diceCorrect_g[0].item(), correct_tmp[0].min().item(), correct_tmp[0].mean().item(), correct_tmp[0].max().item(), correct_tmp.shape])
+            log.debug(['dicePrediction_g', dicePrediction_g[0].item(), prediction_g[0].min().item(), prediction_g[0].mean().item(), prediction_g[0].max().item(), prediction_g.shape])
+            log.debug(['diceLabel_g     ', diceLabel_g[0].item(), label_g[0].min().item(), label_g[0].mean().item(), label_g[0].max().item(), label_g.shape])
+            log.debug(['2*diceCorrect_g ', 2 * diceCorrect_g[0].item()])
+            log.debug(['Prediction + Label      ', dicePrediction_g[0].item()])
+            log.debug(['diceLoss_g      ', diceLoss_g[0].item()])
+            assert False
+
+        return diceLoss_g
+
+
+    def logImages(self, epoch_ndx, mode_str, dl):
+        images_iter = sorted(dl.dataset.series_list)[:12]
+        for series_ndx, series_uid in enumerate(images_iter):
+            ct = getCt(series_uid)
+
+            for slice_ndx in range(6):
+                ct_ndx = slice_ndx * ct.hu_a.shape[0] // 5
+                ct_ndx = min(ct_ndx, ct.hu_a.shape[0] - 1)
+                sample_tup = dl.dataset[(series_uid, ct_ndx, False)]
+
+                ct_t, nodule_t, _, ben_t, mal_t, _, _ = sample_tup
+
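+                # Shift the CT slice channels from roughly [-1, 1] into [0, 1]
+                # for display; the final channel is left unscaled.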
+                ct_t[:-1,:,:] += 1
+                ct_t[:-1,:,:] /= 2
+
+                input_g = ct_t.to(self.device)
+                label_g = nodule_t.to(self.device)
+
+                prediction_g = self.model(input_g.unsqueeze(0))[0]
+                prediction_a = prediction_g.to('cpu').detach().numpy()
+                label_a = nodule_t.numpy()
+                ben_a = ben_t.numpy()
+                mal_a = mal_t.numpy()
+
+                ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()
+
+                image_a = np.zeros((512, 512, 3), dtype=np.float32)
+                image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
+                image_a[:,:,0] += prediction_a[0] * (1 - label_a[0])
+                image_a[:,:,1] += prediction_a[0] * mal_a[0]
+                image_a[:,:,2] += prediction_a[0] * ben_a[0]
+                image_a *= 0.5
+                image_a[image_a < 0] = 0
+                image_a[image_a > 1] = 1
+
+                writer = getattr(self, mode_str + '_writer')
+                writer.add_image(
+                    '{}/{}_prediction_{}'.format(
+                        mode_str,
+                        series_ndx,
+                        slice_ndx,
+                    ),
+                    image_a,
+                    self.totalTrainingSamples_count,
+                    dataformats='HWC',
+                )
+
+                # self.diceLoss(label_g, prediction_g, p=True)
+
+                if epoch_ndx == 1:
+                    image_a = np.zeros((512, 512, 3), dtype=np.float32)
+                    image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
+                    image_a[:,:,0] += (1 - label_a[0]) * ct_t[-1].numpy() # Red
+                    image_a[:,:,1] += mal_a[0]  # Green
+                    image_a[:,:,2] += ben_a[0]  # Blue
+
+                    image_a *= 0.5
+                    image_a[image_a < 0] = 0
+                    image_a[image_a > 1] = 1
+                    writer.add_image(
+                        '{}/{}_label_{}'.format(
+                            mode_str,
+                            series_ndx,
+                            slice_ndx,
+                        ),
+                        image_a,
+                        self.totalTrainingSamples_count,
+                        dataformats='HWC',
+                    )
+
+
+    def logMetrics(self,
+        epoch_ndx,
+        mode_str,
+        metrics_t,
+    ):
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+        metrics_a = metrics_t.detach().numpy()
+        sum_a = metrics_a.sum(axis=1)
+        assert np.isfinite(metrics_a).all()
+
+        malLabel_mask = (metrics_a[METRICS_LABEL_NDX] == 1) | (metrics_a[METRICS_LABEL_NDX] == 3)
+
+        # allLabel_mask = (metrics_a[METRICS_LABEL_NDX] == 2) | (metrics_a[METRICS_LABEL_NDX] == 3)
+
+        allLabel_count = sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]
+        malLabel_count = sum_a[METRICS_MTP_NDX] + sum_a[METRICS_MFN_NDX]
+
+        # allCorrect_count = sum_a[METRICS_ATP_NDX]
+        # malCorrect_count = sum_a[METRICS_MTP_NDX]
+        # falsePos_count = allLabel_count - allCorrect_count
+        # falseNeg_count = malLabel_count - malCorrect_count
+
+
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_a[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/mal'] = np.nan_to_num(metrics_a[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
+        # metrics_dict['loss/all'] = metrics_a[METRICS_ALL_LOSS_NDX, allLabel_mask].mean()
+
+        # metrics_dict['correct/mal'] = sum_a[METRICS_MTP_NDX] / (sum_a[METRICS_MTP_NDX] + sum_a[METRICS_MFN_NDX]) * 100
+        # metrics_dict['correct/all'] = sum_a[METRICS_ATP_NDX] / (sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]) * 100
+
+        metrics_dict['percent_all/tp'] = sum_a[METRICS_ATP_NDX] / (allLabel_count or 1) * 100
+        metrics_dict['percent_all/fn'] = sum_a[METRICS_AFN_NDX] / (allLabel_count or 1) * 100
+        metrics_dict['percent_all/fp'] = sum_a[METRICS_AFP_NDX] / (allLabel_count or 1) * 100
+
+        metrics_dict['percent_mal/tp'] = sum_a[METRICS_MTP_NDX] / (malLabel_count or 1) * 100
+        metrics_dict['percent_mal/fn'] = sum_a[METRICS_MFN_NDX] / (malLabel_count or 1) * 100
+
+        precision = metrics_dict['pr/precision'] = sum_a[METRICS_ATP_NDX] \
+            / ((sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFP_NDX]) or 1)
+        recall    = metrics_dict['pr/recall']    = sum_a[METRICS_ATP_NDX] \
+            / ((sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]) or 1)
+
+        metrics_dict['pr/f1_score'] = 2 * (precision * recall) \
+            / ((precision + recall) or 1)
+
+        log.info(("E{} {:8} "
+                 + "{loss/all:.4f} loss, "
+                 + "{pr/precision:.4f} precision, "
+                 + "{pr/recall:.4f} recall, "
+                 + "{pr/f1_score:.4f} f1 score"
+                  ).format(
+            epoch_ndx,
+            mode_str,
+            **metrics_dict,
+        ))
+        log.info(("E{} {:8} "
+                  + "{loss/all:.4f} loss, "
+                  + "{percent_all/tp:-5.1f}% tp, {percent_all/fn:-5.1f}% fn, {percent_all/fp:-9.1f}% fp"
+                 # + "{correct/all:-5.1f}% correct ({allCorrect_count:} of {allLabel_count:})"
+        ).format(
+            epoch_ndx,
+            mode_str + '_all',
+            # allCorrect_count=allCorrect_count,
+            # allLabel_count=allLabel_count,
+            **metrics_dict,
+        ))
+
+        log.info(("E{} {:8} "
+                  + "{loss/mal:.4f} loss, "
+                  + "{percent_mal/tp:-5.1f}% tp, {percent_mal/fn:-5.1f}% fn"
+                 # + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+        ).format(
+            epoch_ndx,
+            mode_str + '_mal',
+            # malCorrect_count=malCorrect_count,
+            # malLabel_count=malLabel_count,
+            **metrics_dict,
+        ))
+
+        self.initTensorboardWriters()
+        writer = getattr(self, mode_str + '_writer')
+
+        prefix_str = 'seg_'
+
+        for key, value in metrics_dict.items():
+            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)
+
+        score = 1 \
+            - metrics_dict['loss/mal'] \
+            + metrics_dict['pr/f1_score'] \
+            - metrics_dict['pr/recall'] * 0.01 \
+            - metrics_dict['loss/all']  * 0.0001
+
+        return score
+
+    # def logModelMetrics(self, model):
+    #     writer = getattr(self, 'trn_writer')
+    #
+    #     model = getattr(model, 'module', model)
+    #
+    #     for name, param in model.named_parameters():
+    #         if param.requires_grad:
+    #             min_data = float(param.data.min())
+    #             max_data = float(param.data.max())
+    #             max_extent = max(abs(min_data), abs(max_data))
+    #
+    #             # bins = [x/50*max_extent for x in range(-50, 51)]
+    #
+    #             writer.add_histogram(
+    #                 name.rsplit('.', 1)[-1] + '/' + name,
+    #                 param.data.cpu().numpy(),
+    #                 # metrics_a[METRICS_PRED_NDX, benHist_mask],
+    #                 self.totalTrainingSamples_count,
+    #                 # bins=bins,
+    #             )
+    #
+    #             # print name, param.data
+
+    def saveModel(self, type_str, epoch_ndx, isBest=False):
+        file_path = os.path.join(
+            'data-unversioned',
+            'part2',
+            'models',
+            self.cli_args.tb_prefix,
+            '{}_{}_{}.{}.state'.format(
+                type_str,
+                self.time_str,
+                self.cli_args.comment,
+                self.totalTrainingSamples_count,
+            )
+        )
+
+        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
+
+        model = self.model
+        if hasattr(model, 'module'):
+            model = model.module
+
+        state = {
+            'model_state': model.state_dict(),
+            'model_name': type(model).__name__,
+            'optimizer_state' : self.optimizer.state_dict(),
+            'optimizer_name': type(self.optimizer).__name__,
+            'epoch': epoch_ndx,
+            'totalTrainingSamples_count': self.totalTrainingSamples_count,
+        }
+        torch.save(state, file_path)
+
+        log.debug("Saved model params to {}".format(file_path))
+
+        if isBest:
+            file_path = os.path.join(
+                'data-unversioned',
+                'part2',
+                'models',
+                self.cli_args.tb_prefix,
+                '{}_{}_{}.{}.state'.format(
+                    type_str,
+                    self.time_str,
+                    self.cli_args.comment,
+                    'best',
+                )
+            )
+            torch.save(state, file_path)
+
+            log.debug("Saved model params to {}".format(file_path))
+
+
+if __name__ == '__main__':
+    LunaTrainingApp().main()
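
The soft Dice loss in diceLoss above is easy to check by hand. A self-contained sketch of the same per-sample formula (epsilon=1) on a toy mask; dice_loss here is a local stand-in, not project code:

    import torch

    def dice_loss(label, pred, epsilon=1.0):
        flat = lambda t: t.view(t.size(0), -1)           # keep only the batch dim
        correct = (flat(pred) * flat(label)).sum(dim=1)  # per-sample overlap
        return 1 - (2 * correct + epsilon) \
            / (flat(pred).sum(dim=1) + flat(label).sum(dim=1) + epsilon)

    label = torch.zeros(1, 1, 4, 4)
    label[0, 0, 1:3, 1:3] = 1                         # four positive pixels
    print(dice_loss(label, label.clone()))            # perfect overlap: tensor([0.])
    print(dice_loss(label, torch.zeros_like(label)))  # total miss: 1 - 1/5 -> tensor([0.8000])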

+ 702 - 0
p2ch13/training.py

@@ -0,0 +1,702 @@
+import argparse
+import datetime
+import os
+import socket
+import sys
+
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.optim import SGD, Adam
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import TrainingLuna2dSegmentationDataset, TestingLuna2dSegmentationDataset, LunaClassificationDataset, getCt
+from util.logconf import logging
+from util.util import xyz2irc
+from .model import UNetWrapper, LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+# Used for computeClassificationLoss/computeSegmentationLoss and
+# logPerformanceMetrics to index into metrics_tensor/metrics_ary
+# METRICS_LABEL_NDX=0
+# METRICS_PRED_NDX=1
+# METRICS_LOSS_NDX=2
+# METRICS_MAL_LOSS_NDX=3
+# METRICS_BEN_LOSS_NDX=4
+# METRICS_LUNG_LOSS_NDX=5
+# METRICS_MASKLOSS_NDX=2
+# METRICS_MALLOSS_NDX=3
+
+
+METRICS_LOSS_NDX = 0
+METRICS_LABEL_NDX = 1
+METRICS_PRED_NDX = 2
+
+METRICS_MTP_NDX = 3
+METRICS_MFN_NDX = 4
+METRICS_MFP_NDX = 5
+METRICS_BTP_NDX = 6
+METRICS_BFN_NDX = 7
+METRICS_BFP_NDX = 8
+
+METRICS_MAL_LOSS_NDX = 9
+METRICS_BEN_LOSS_NDX = 10
+
+# METRICS_MFOUND_NDX = 2
+
+# METRICS_MOK_NDX = 2
+
+# METRICS_FLG_LOSS_NDX = 10
+METRICS_SIZE = 11
+
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=4,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+        # parser.add_argument('--resume',
+        #     default=None,
+        #     help="File to resume training from.",
+        # )
+
+        parser.add_argument('--segmentation',
+            help="TODO", # TODO
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--balanced',
+            help="Balance the training data to half benign, half malignant.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--adaptive',
+            help="Balance the training data to start half benign, half malignant, and end at a 100:1 ratio.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--scaled',
+            help="Scale the CT chunks to square voxels.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--multiscaled',
+            help="Scale the CT chunks to square voxels.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augmented',
+            help="Augment the training data (implies --scaled).",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch10',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='none',
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
+
+        self.trn_writer = None
+        self.tst_writer = None
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        if socket.gethostname() == 'c2':
+            self.device = torch.device("cuda:1") # TODO: remove me before print
+
+        self.model = self.initModel()
+        self.optimizer = self.initOptimizer()
+
+        self.totalTrainingSamples_count = 0
+
+
+
+    def initModel(self):
+        if self.cli_args.segmentation:
+            model = UNetWrapper(in_channels=8, n_classes=2, depth=5, wf=6, padding=True, batch_norm=True, up_mode='upconv')
+        else:
+            model = LunaModel()
+
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                if socket.gethostname() == 'c2':
+                    model = nn.DataParallel(model, device_ids=[1, 0]) # TODO: remove me before print
+                else:
+                    model = nn.DataParallel(model)
+
+            model = model.to(self.device)
+
+
+        return model
+
+    def initOptimizer(self):
+
+        # self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.99)
+        return Adam(self.model.parameters())
+
+
+    def initTrainDl(self):
+        if self.cli_args.segmentation:
+            train_ds = TrainingLuna2dSegmentationDataset(
+                    test_stride=10,
+                    contextSlices_count=3,
+                )
+        else:
+            train_ds = LunaClassificationDataset(
+                 test_stride=10,
+                 isTestSet_bool=False,
+                 # series_uid=None,
+                 # sortby_str='random',
+                 ratio_int=int(self.cli_args.balanced),
+                 # scaled_bool=False,
+                 # multiscaled_bool=False,
+                 # augmented_bool=False,
+                 # noduleInfo_list=None,
+            )
+
+        train_dl = DataLoader(
+            train_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return train_dl
+
+    def initTestDl(self):
+        if self.cli_args.segmentation:
+            test_ds = TestingLuna2dSegmentationDataset(
+                    test_stride=10,
+                    contextSlices_count=3,
+                )
+        else:
+            test_ds = LunaClassificationDataset(
+                 test_stride=10,
+                 isTestSet_bool=True,
+                 # series_uid=None,
+                 # sortby_str='random',
+                 # ratio_int=int(self.cli_args.balanced),
+                 # scaled_bool=False,
+                 # multiscaled_bool=False,
+                 # augmented_bool=False,
+                 # noduleInfo_list=None,
+            )
+
+        test_dl = DataLoader(
+            test_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return test_dl
+
+    def initTensorboardWriters(self):
+        if self.trn_writer is None:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '_segtrn_' + self.cli_args.comment)
+            self.tst_writer = SummaryWriter(log_dir=log_dir + '_segtst_' + self.cli_args.comment)
+
+
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+        test_dl = self.initTestDl()
+
+        self.initTensorboardWriters()
+        self.logModelMetrics(self.model)
+
+        best_score = 0.0
+
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(test_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
+            if epoch_ndx > 0:
+                self.logPerformanceMetrics(epoch_ndx, 'trn', trainingMetrics_tensor)
+
+            self.logModelMetrics(self.model)
+
+            if self.cli_args.segmentation:
+                self.logImages(epoch_ndx, train_dl, test_dl)
+
+            testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
+            score = self.logPerformanceMetrics(epoch_ndx, 'tst', testingMetrics_tensor)
+            best_score = max(score, best_score)
+
+            self.saveModel('seg' if self.cli_args.segmentation else 'cls', epoch_ndx, score == best_score)
+
+        if hasattr(self, 'trn_writer'):
+            self.trn_writer.close()
+            self.tst_writer.close()
+
+    def doTraining(self, epoch_ndx, train_dl):
+        self.model.train()
+        trainingMetrics_tensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset))
+        # train_dl.dataset.shuffleSamples()
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            if self.cli_args.segmentation:
+                loss_var = self.computeSegmentationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
+            else:
+                loss_var = self.computeClassificationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
+
+            if loss_var is not None:
+                loss_var.backward()
+                self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trainingMetrics_tensor.size(1)
+
+        return trainingMetrics_tensor
+
+    def doTesting(self, epoch_ndx, test_dl):
+        with torch.no_grad():
+            self.model.eval()
+            testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
+            batch_iter = enumerateWithEstimate(
+                test_dl,
+                "E{} Testing ".format(epoch_ndx),
+                start_ndx=test_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                if self.cli_args.segmentation:
+                    self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+                else:
+                    self.computeClassificationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+
+        return testingMetrics_tensor
+
+    def computeClassificationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
+        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+
+        input_devtensor = input_tensor.to(self.device)
+        label_devtensor = label_tensor.to(self.device)
+
+        prediction_devtensor = self.model(input_devtensor)
+        loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
+
+        with torch.no_grad():
+            # log.debug([metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx].shape, label_tensor.shape])
+
+            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor[:,0]
+            metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.to('cpu')[:,0]
+            # metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')
+
+            prediction_tensor = prediction_devtensor.to('cpu', non_blocking=True)
+            loss_tensor = loss_devtensor.to('cpu', non_blocking=True)[:,0]
+            malLabel_tensor = (label_tensor > 0.5)[:,0]
+            benLabel_tensor = ~malLabel_tensor
+
+
+            # Index [:,0] so labels and predictions are both per-sample
+            # booleans; without it the elementwise op broadcasts (B,)
+            # against (B, 1) into a (B, B) matrix. Each AND below is
+            # already a per-sample 0/1 value, so no reduction is needed.
+            malPred_tensor = (prediction_tensor > 0.5)[:,0]
+            benPred_tensor = ~malPred_tensor
+            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = (malLabel_tensor & malPred_tensor)
+            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = (malLabel_tensor & benPred_tensor)
+            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = (benLabel_tensor & malPred_tensor)
+
+            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = (benLabel_tensor & benPred_tensor)
+            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = (benLabel_tensor & malPred_tensor)
+            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = (malLabel_tensor & benPred_tensor)
+
+            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_tensor
+
+            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = loss_tensor * benLabel_tensor.type(torch.float32)
+            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = loss_tensor * malLabel_tensor.type(torch.float32)
+
+
+        # TODO: replace with torch.autograd.detect_anomaly
+        # assert np.isfinite(metrics_tensor).all()
+
+        return loss_devtensor.mean()
+
+    def computeSegmentationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
+        input_tensor, label_tensor, _series_list, _start_list = batch_tup
+
+        # if label_tensor.max() < 0.5:
+        #     return None
+
+        input_devtensor = input_tensor.to(self.device)
+        label_devtensor = label_tensor.to(self.device)
+
+        prediction_devtensor = self.model(input_devtensor)
+
+        # assert prediction_devtensor.is_contiguous()
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
+        max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
+        intersectionSum = lambda a, b: (a * b.to(torch.float32)).view(a.size(0), -1).sum(dim=1)
+
+        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
+        malLoss_devtensor = self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0])
+        benLoss_devtensor = self.diceLoss(label_devtensor[:,1], prediction_devtensor[:,1])
+
+        with torch.no_grad():
+            bPred_tensor = prediction_devtensor.to('cpu', non_blocking=True)
+            diceLoss_tensor = diceLoss_devtensor.to('cpu', non_blocking=True)
+            malLoss_tensor = malLoss_devtensor.to('cpu', non_blocking=True)
+            benLoss_tensor = benLoss_devtensor.to('cpu', non_blocking=True)
+
+            # flgLoss_devtensor = self.diceLoss(label_devtensor[:,0], label_devtensor[:,0] * prediction_devtensor[:,1])
+            # flgLoss_tensor = flgLoss_devtensor.to('cpu', non_blocking=True)#.unsqueeze(1)
+
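+            # Encode each sample's label: 1 = malignant mask present,
+            # 2 = benign mask present, 3 = both; logPerformanceMetrics
+            # decodes this with the == 1 / == 3 masks.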
+            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = max2(label_tensor[:,0]) + max2(label_tensor[:,1]) * 2
+            # metrics_tensor[METRICS_MFOUND_NDX, start_ndx:end_ndx] = (max2(label_tensor[:, 0] * bPred_tensor[:, 1].to(torch.float32)) > 0.5)
+
+            # metrics_tensor[METRICS_MOK_NDX, start_ndx:end_ndx] = intersectionSum( label_tensor[:,0],  bPred_tensor[:,1])
+
+            bPred_tensor = bPred_tensor > 0.5
+            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0],  bPred_tensor[:,0])
+            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0], ~bPred_tensor[:,0])
+            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,0],  bPred_tensor[:,0])
+
+            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,1],  bPred_tensor[:,1])
+            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,1], ~bPred_tensor[:,1])
+            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,1],  bPred_tensor[:,1])
+
+            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_tensor
+
+            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = benLoss_tensor
+            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_tensor
+            # metrics_tensor[METRICS_FLG_LOSS_NDX, start_ndx:end_ndx] = flgLoss_tensor
+
+
+            # lungLoss_devtensor = self.diceLoss(label_devtensor[:,2], prediction_devtensor[:,2])
+            # lungLoss_tensor = lungLoss_devtensor.to('cpu').unsqueeze(1)
+            # metrics_tensor[METRICS_LUNG_LOSS_NDX, start_ndx:end_ndx] = lungLoss_tensor
+
+        # TODO: replace with torch.autograd.detect_anomaly
+        # assert np.isfinite(metrics_tensor).all()
+
+        # return nn.MSELoss()(prediction_devtensor, label_devtensor)
+
+        return malLoss_devtensor.mean() + benLoss_devtensor.mean()
+        # return self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0]).mean()
+
+    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01, p=False):
+        # sum2 = lambda t: t.sum([1,2,3,4])
+        sum2 = lambda t: t.view(t.size(0), -1).sum(dim=1)
+        # max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
+
+        diceCorrect_devtensor = sum2(prediction_devtensor * label_devtensor)
+        dicePrediction_devtensor = sum2(prediction_devtensor)
+        diceLabel_devtensor = sum2(label_devtensor)
+        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
+        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)
+
+        if not torch.isfinite(diceLoss_devtensor).all():
+            log.debug('')
+            log.debug('diceLoss_devtensor')
+            log.debug(diceLoss_devtensor.to('cpu'))
+            log.debug('diceCorrect_devtensor')
+            log.debug(diceCorrect_devtensor.to('cpu'))
+            log.debug('dicePrediction_devtensor')
+            log.debug(dicePrediction_devtensor.to('cpu'))
+            log.debug('diceLabel_devtensor')
+            log.debug(diceLabel_devtensor.to('cpu'))
+
+        return diceLoss_devtensor
+
+
+
+    def logImages(self, epoch_ndx, train_dl, test_dl):
+        for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
+            for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
+                ct = getCt(series_uid)
+                noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
+                center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)
+
+                sample_tup = dl.dataset[(series_uid, int(center_irc.index))]
+                input_tensor = sample_tup[0].unsqueeze(0)
+                label_tensor = sample_tup[1].unsqueeze(0)
+
+                input_devtensor = input_tensor.to(self.device)
+                label_devtensor = label_tensor.to(self.device)
+
+                prediction_devtensor = self.model(input_devtensor)
+                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()
+
+                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
+                image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
+                image_ary[:,:,0] += prediction_ary[0,0] * 0.5
+                image_ary[:,:,1] += prediction_ary[0,1] * 0.25
+                # image_ary[:,:,2] += prediction_ary[0,2] * 0.5
+
+                # log.debug([image_ary.__array_interface__['typestr']])
+
+                # image_ary = (image_ary * 255).astype(np.uint8)
+
+                # log.debug([image_ary.__array_interface__['typestr']])
+
+                writer = getattr(self, mode_str + '_writer')
+                try:
+                    writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+                except Exception:
+                    log.debug([image_ary.shape, image_ary.dtype])
+                    raise
+
+                if epoch_ndx == 1:
+                    label_ary = label_tensor.numpy()
+
+                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
+                    image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
+                    image_ary[:,:,0] += label_ary[0,0] * 0.5
+                    image_ary[:,:,1] += label_ary[0,1] * 0.25
+                    image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(np.bool) | label_ary[0,1].astype(np.bool))) * 0.25
+
+                    # log.debug([image_ary.__array_interface__['typestr']])
+
+                    image_ary = (image_ary * 255).astype(np.uint8)
+
+                    # log.debug([image_ary.__array_interface__['typestr']])
+
+                    writer = getattr(self, mode_str + '_writer')
+                    writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+
+
+    def logPerformanceMetrics(self,
+                              epoch_ndx,
+                              mode_str,
+                              metrics_tensor,
+                              # trainingMetrics_tensor,
+                              # testingMetrics_tensor,
+                              classificationThreshold_float=0.5,
+                              ):
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+        score = 0.0
+
+
+        # for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
+        metrics_ary = metrics_tensor.cpu().detach().numpy()
+        sum_ary = metrics_ary.sum(axis=1)
+        assert np.isfinite(metrics_ary).all()
+
+        malLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 1) | (metrics_ary[METRICS_LABEL_NDX] == 3)
+
+        if self.cli_args.segmentation:
+            benLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 2) | (metrics_ary[METRICS_LABEL_NDX] == 3)
+        else:
+            benLabel_mask = ~malLabel_mask
+        # malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
+
+        # malLabel_mask = ~benLabel_mask
+        # malPred_mask = ~benPred_mask
+
+        benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
+        malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]
+
+        trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
+        truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]
+        # falsePos_count = benLabel_count - benCorrect_count
+        # falseNeg_count = malLabel_count - malCorrect_count
+
+
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+        # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
+        # metrics_dict['loss/mal'] = metrics_ary[METRICS_MALLOSS_NDX].mean()
+        # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = np.nan_to_num(metrics_ary[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
+        metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX, benLabel_mask].mean()
+        # metrics_dict['loss/flg'] = metrics_ary[METRICS_FLG_LOSS_NDX].mean()
+
+        # metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
+        # metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100
+
+        metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
+        metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100
+
+        precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
+        recall    = metrics_dict['pr/recall']    = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
+
+        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)
+
+        log.info(("E{} {:8} "
+                 + "{loss/all:.4f} loss, "
+                 # + "{loss/flg:.4f} flagged loss, "
+                 # + "{flagged/all:-5.1f}% pixels flagged, "
+                 # + "{flagged/slices:-5.1f}% slices flagged, "
+                 + "{pr/precision:.4f} precision, "
+                 + "{pr/recall:.4f} recall, "
+                 + "{pr/f1_score:.4f} f1 score"
+                  ).format(
+            epoch_ndx,
+            mode_str,
+            **metrics_dict,
+        ))
+        log.info(("E{} {:8} "
+                 + "{loss/mal:.4f} loss, "
+                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+        ).format(
+            epoch_ndx,
+            mode_str + '_mal',
+            malCorrect_count=malCorrect_count,
+            malLabel_count=malLabel_count,
+            **metrics_dict,
+        ))
+        log.info(("E{} {:8} "
+                 + "{loss/ben:.4f} loss, "
+                 + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
+        ).format(
+            epoch_ndx,
+            mode_str + '_ben',
+            benCorrect_count=benCorrect_count,
+            benLabel_count=benLabel_count,
+            **metrics_dict,
+        ))
+
+        writer = getattr(self, mode_str + '_writer')
+
+        for key, value in metrics_dict.items():
+            writer.add_scalar('seg_' + key, value, self.totalTrainingSamples_count)
+
+        # The PR curve and histograms only apply to the classification model,
+        # and should be logged once per epoch, not once per scalar key.
+        if not self.cli_args.segmentation:
+            writer.add_pr_curve(
+                'pr',
+                metrics_ary[METRICS_LABEL_NDX],
+                metrics_ary[METRICS_PRED_NDX],
+                self.totalTrainingSamples_count,
+            )
+
+            benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
+            malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
+
+            bins = [x/50.0 for x in range(51)]
+            writer.add_histogram(
+                'is_ben',
+                metrics_ary[METRICS_PRED_NDX, benHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+            writer.add_histogram(
+                'is_mal',
+                metrics_ary[METRICS_PRED_NDX, malHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+
+        score = 1 \
+            + metrics_dict['pr/f1_score'] \
+            - metrics_dict['loss/mal'] * 0.01 \
+            - metrics_dict['loss/all'] * 0.0001
+
+        return score
+
+    def logModelMetrics(self, model):
+        writer = getattr(self, 'trn_writer')
+
+        model = getattr(model, 'module', model)
+
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                min_data = float(param.data.min())
+                max_data = float(param.data.max())
+                max_extent = max(abs(min_data), abs(max_data))
+
+                bins = [x/50*max_extent for x in range(-50, 51)]
+
+                writer.add_histogram(
+                    name.rsplit('.', 1)[-1] + '/' + name,
+                    param.data.cpu().numpy(),
+                    # metrics_ary[METRICS_PRED_NDX, benHist_mask],
+                    self.totalTrainingSamples_count,
+                    bins=bins,
+                )
+
+                # print name, param.data
+
+    def saveModel(self, type_str, epoch_ndx, isBest=False):
+        file_path = os.path.join(
+            'data-unversioned',
+            'models',
+            self.cli_args.tb_prefix,
+            '{}_{}_{}.{}.state'.format(
+                type_str,
+                self.time_str,
+                self.cli_args.comment,
+                self.totalTrainingSamples_count,
+            )
+        )
+
+        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
+
+        model = self.model
+        if hasattr(model, 'module'):
+            model = model.module
+
+        state = {
+            'model_state': model.state_dict(),
+            'model_name': type(model).__name__,
+            'optimizer_state' : self.optimizer.state_dict(),
+            'optimizer_name': type(self.optimizer).__name__,
+            'epoch': epoch_ndx,
+            'totalTrainingSamples_count': self.totalTrainingSamples_count,
+            # 'resumed_from': self.cli_args.resume,
+        }
+        torch.save(state, file_path)
+
+        log.debug("Saved model params to {}".format(file_path))
+
+        if isBest:
+            file_path = os.path.join(
+                'data-unversioned',
+                'models',
+                self.cli_args.tb_prefix,
+                '{}_{}_{}.{}.state'.format(
+                    type_str,
+                    self.time_str,
+                    self.cli_args.comment,
+                    'best',
+                )
+            )
+            torch.save(state, file_path)
+
+            log.debug("Saved model params to {}".format(file_path))
+
+
+if __name__ == '__main__':
+    sys.exit(LunaTrainingApp().main() or 0)
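
One detail in logPerformanceMetrics worth calling out: the `or 1` division guards rely on numpy scalars being falsy exactly when they equal zero. A standalone sketch of the pattern:

    import numpy as np

    tp, fp, fn = np.float64(0.0), np.float64(0.0), np.float64(3.0)

    # A zero-valued numpy scalar is falsy, so `or 1` swaps in a safe denominator.
    precision = tp / ((tp + fp) or 1)
    recall = tp / ((tp + fn) or 1)
    f1 = 2 * (precision * recall) / ((precision + recall) or 1)
    print(precision, recall, f1)  # 0.0 0.0 0.0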

+ 99 - 0
p2ch13/vis.py

@@ -0,0 +1,99 @@
+import matplotlib
+matplotlib.use('nbagg')
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from p2ch13.dsets import Ct, LunaDataset
+
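+# Hounsfield-unit display window: -1000 (air) up to +300 (dense soft tissue).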
+clim=(-1000.0, 300)
+
+def findMalignantSamples(start_ndx=0, limit=100):
+    ds = LunaDataset(sortby_str='malignancy_size')
+
+    malignantSample_list = []
+    for sample_tup in ds.noduleInfo_list:
+        if sample_tup.isMalignant_bool:
+            print(len(malignantSample_list), sample_tup)
+            malignantSample_list.append(sample_tup)
+
+        if len(malignantSample_list) >= limit:
+            break
+
+    return malignantSample_list
+
+def showNodule(series_uid, batch_ndx=None, **kwargs):
+    ds = LunaDataset(series_uid=series_uid, **kwargs)
+    malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x.isMalignant_bool]
+
+    if batch_ndx is None:
+        if malignant_list:
+            batch_ndx = malignant_list[0]
+        else:
+            print("Warning: no malignant samples found; using first non-malignant sample.")
+            batch_ndx = 0
+
+    ct = Ct(series_uid)
+    ct_t, malignant_t, series_uid, center_irc = ds[batch_ndx]
+    ct_a = ct_t[0].numpy()
+
+    fig = plt.figure(figsize=(30, 50))
+
+    group_list = [
+        [9, 11, 13],
+        [15, 16, 17],
+        [19, 21, 23],
+    ]
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
+    subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[int(center_irc.index)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
+    subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[:,int(center_irc.row)], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
+    subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
+    subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[ct_a.shape[0]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
+    subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[:,ct_a.shape[1]//2], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
+    subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[:,:,ct_a.shape[2]//2], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
+
+    for row, index_list in enumerate(group_list):
+        for col, index in enumerate(index_list):
+            subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
+            subplot.set_title('slice {}'.format(index), fontsize=30)
+            for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+                label.set_fontsize(20)
+            plt.imshow(ct_a[index], clim=clim, cmap='gray')
+
+
+    print(series_uid, batch_ndx, bool(malignant_t[0]), malignant_list)
+
+
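
vis.py is meant to be driven from the exploration notebooks; a usage sketch (it assumes, as the rest of the code suggests, that the nodule tuples carry a series_uid field):

    from p2ch13.vis import findMalignantSamples, showNodule

    samples = findMalignantSamples(limit=5)
    showNodule(samples[0].series_uid)  # assumes NoduleInfoTuple has series_uid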

File diff suppressed because it is too large
+ 79 - 0
p2ch13_explore_data.ipynb


File diff suppressed because it is too large
+ 113 - 0
p2ch13_explore_diagnose.ipynb


+ 0 - 0
p2ch14/__init__.py


+ 372 - 0
p2ch14/diagnose.py

@@ -0,0 +1,372 @@
+import argparse
+import glob
+import os
+import sys
+
+import numpy as np
+import scipy.ndimage.measurements as measure
+import scipy.ndimage.morphology as morph
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset, Luna2dSegmentationDataset, getCt, getNoduleInfoList, NoduleInfoTuple
+from p2ch13.model_seg import UNetWrapper
+from p2ch13.model_cls import LunaModel, AlternateLunaModel
+
+from util.logconf import logging
+from util.util import xyz2irc, irc2xyz
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+
+class LunaDiagnoseApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            log.debug(sys.argv)
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=4,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+
+        parser.add_argument('--series-uid',
+            help='Limit inference to this Series UID only.',
+            default=None,
+            type=str,
+        )
+
+        parser.add_argument('--include-train',
+            help="Include data that was in the training set. (default: validation data only)",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--segmentation-path',
+            help="Path to the saved segmentation model",
+            nargs='?',
+            default=None,
+        )
+
+        parser.add_argument('--classification-path',
+            help="Path to the saved classification model",
+            nargs='?',
+            default=None,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch13',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        # self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        if not self.cli_args.segmentation_path:
+            self.cli_args.segmentation_path = self.initModelPath('seg')
+
+        if not self.cli_args.classification_path:
+            self.cli_args.classification_path = self.initModelPath('cls')
+
+        self.seg_model, self.cls_model = self.initModels()
+
+    def initModelPath(self, type_str):
+        local_path = os.path.join(
+            'data-unversioned',
+            'part2',
+            'models',
+            self.cli_args.tb_prefix,
+            type_str + '_{}_{}.{}.state'.format('*', '*', 'best'),
+        )
+
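+        # Prefer a locally trained '*.best.state' checkpoint; if none is
+        # found, fall back to any pretrained model shipped under
+        # data/part2/models.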
+        file_list = glob.glob(local_path)
+        if not file_list:
+            pretrained_path = os.path.join(
+                'data',
+                'part2',
+                'models',
+                type_str + '_{}_{}.{}.state'.format('*', '*', '*'),
+            )
+            file_list = glob.glob(pretrained_path)
+        else:
+            pretrained_path = None
+
+        file_list.sort()
+
+        try:
+            return file_list[-1]
+        except IndexError:
+            log.debug([local_path, pretrained_path, file_list])
+            raise
+
+    def initModels(self):
+        log.debug(self.cli_args.segmentation_path)
+        seg_dict = torch.load(self.cli_args.segmentation_path)
+
+        seg_model = UNetWrapper(in_channels=8, n_classes=1, depth=4, wf=3, padding=True, batch_norm=True, up_mode='upconv')
+        seg_model.load_state_dict(seg_dict['model_state'])
+        seg_model.eval()
+
+        log.debug(self.cli_args.classification_path)
+        cls_dict = torch.load(self.cli_args.classification_path)
+
+        cls_model = LunaModel()
+        # cls_model = AlternateLunaModel()
+        cls_model.load_state_dict(cls_dict['model_state'])
+        cls_model.eval()
+
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                seg_model = nn.DataParallel(seg_model)
+                cls_model = nn.DataParallel(cls_model)
+
+            seg_model = seg_model.to(self.device)
+            cls_model = cls_model.to(self.device)
+
+        return seg_model, cls_model
+
+
+    def initSegmentationDl(self, series_uid):
+        seg_ds = Luna2dSegmentationDataset(
+                contextSlices_count=3,
+                series_uid=series_uid,
+                fullCt_bool=True,
+            )
+        seg_dl = DataLoader(
+            seg_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return seg_dl
+
+    def initClassificationDl(self, noduleInfo_list):
+        cls_ds = LunaDataset(
+                sortby_str='series_uid',
+                noduleInfo_list=noduleInfo_list,
+            )
+        cls_dl = DataLoader(
+            cls_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return cls_dl
+
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
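+        # Rebuild the validation split (same val_stride as training) so the
+        # results below can be reported per training/validation series.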
+        val_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=True,
+        )
+        val_set = set(
+            noduleInfo_tup.series_uid
+            for noduleInfo_tup in val_ds.noduleInfo_list
+        )
+        malignant_set = set(
+            noduleInfo_tup.series_uid
+            for noduleInfo_tup in getNoduleInfoList()
+            if noduleInfo_tup.isMalignant_bool
+        )
+
+        if self.cli_args.series_uid:
+            series_set = set(self.cli_args.series_uid.split(','))
+        else:
+            series_set = set(
+                noduleInfo_tup.series_uid
+                for noduleInfo_tup in getNoduleInfoList()
+            )
+
+        train_list = sorted(series_set - val_set) if self.cli_args.include_train else []
+        val_list = sorted(series_set & val_set)
+
+
+        noduleInfo_list = []
+        series_iter = enumerateWithEstimate(
+            val_list + train_list,
+            "Series",
+        )
+        for _series_ndx, series_uid in series_iter:
+            ct, output_a, _mask_a, clean_a = self.segmentCt(series_uid)
+
+            noduleInfo_list += self.clusterSegmentationOutput(
+                series_uid,
+                ct,
+                clean_a,
+            )
+
+            # if _series_ndx > 10:
+            #     break
+
+
+        cls_dl = self.initClassificationDl(noduleInfo_list)
+
+        series2diagnosis_dict = {}
+        batch_iter = enumerateWithEstimate(
+            cls_dl,
+            "Cls all",
+            start_ndx=cls_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            input_t, _, series_list, center_list = batch_tup
+
+            input_g = input_t.to(self.device)
+            with torch.no_grad():
+                _logits_g, probability_g = self.cls_model(input_g)
+
+            classifications_list = zip(
+                series_list,
+                center_list,
+                probability_g[:,1].to('cpu'),
+            )
+
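+            # A series' diagnosis is the single highest-probability nodule
+            # found anywhere in its CT, so keep a running max of
+            # (probability, center) per series_uid.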
+            for cls_tup in classifications_list:
+                series_uid, center_irc, probability_t = cls_tup
+                probability_float = probability_t.item()
+
+                this_tup = (probability_float, tuple(center_irc))
+                current_tup = series2diagnosis_dict.get(series_uid, this_tup)
+                try:
+                    assert np.all(np.isfinite(tuple(center_irc)))
+                    if this_tup > current_tup:
+                        log.debug([series_uid, this_tup])
+                    series2diagnosis_dict[series_uid] = max(this_tup, current_tup)
+                except:
+                    log.debug([(type(x), x) for x in this_tup] + [(type(x), x) for x in current_tup])
+                    raise
+
+        log.info('Training set:')
+        self.logResults('Training', train_list, series2diagnosis_dict, malignant_set)
+
+        log.info('Validation set:')
+        self.logResults('Validation', val_list, series2diagnosis_dict, malignant_set)
+
+    def segmentCt(self, series_uid):
+        with torch.no_grad():
+            ct = getCt(series_uid)
+
+            output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
+
+            seg_dl = self.initSegmentationDl(series_uid)
+            for batch_tup in seg_dl:
+                input_t = batch_tup[0]
+                ndx_list = batch_tup[6]
+
+                input_g = input_t.to(self.device)
+                prediction_g = self.seg_model(input_g)
+
+                for i, sample_ndx in enumerate(ndx_list):
+                    output_a[sample_ndx] = prediction_g[i].cpu().numpy()
+
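+            # Threshold the per-voxel probabilities, then erode once to drop
+            # single-voxel specks and dilate twice to restore the surviving
+            # blobs with a one-voxel margin.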
+            mask_a = output_a > 0.5
+            clean_a = morph.binary_erosion(mask_a, iterations=1)
+            clean_a = morph.binary_dilation(clean_a, iterations=2)
+
+        return ct, output_a, mask_a, clean_a
+
+    def clusterSegmentationOutput(self, series_uid,  ct, clean_a):
+        noduleLabel_a, nodule_count = measure.label(clean_a)
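+        # center_of_mass expects non-negative weights; HU values are clamped
+        # to [-1000, 1000] in Ct.__init__, so shifting by +1001 makes every
+        # voxel weight strictly positive.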
+        centerIrc_list = measure.center_of_mass(
+            ct.hu_a + 1001,
+            labels=noduleLabel_a,
+            index=list(range(1, nodule_count+1)),
+        )
+
+        # n = 1298
+        # log.debug([
+        #     (noduleLabel_a == n).sum(),
+        #     np.where(noduleLabel_a == n),
+        #
+        #     ct.hu_a[noduleLabel_a == n].sum(),
+        #     (ct.hu_a + 1000)[noduleLabel_a == n].sum(),
+        # ])
+
+        # if nodule_count == 1:
+        #     centerIrc_list = [centerIrc_list]
+
+        noduleInfo_list = []
+        for i, center_irc in enumerate(centerIrc_list):
+            center_xyz = irc2xyz(
+                center_irc,
+                ct.origin_xyz,
+                ct.vxSize_xyz,
+                ct.direction_tup,
+            )
+            assert np.all(np.isfinite(center_irc)), repr(['irc', center_irc, i, nodule_count])
+            assert np.all(np.isfinite(center_xyz)), repr(['xyz', center_xyz])
+            noduleInfo_tup = \
+                NoduleInfoTuple(False, 0.0, series_uid, center_xyz)
+            noduleInfo_list.append(noduleInfo_tup)
+
+        return noduleInfo_list
+
+    def logResults(self, mode_str, filtered_list, series2diagnosis_dict, malignant_set):
+        count_dict = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
+        for series_uid in filtered_list:
+            probability_float, center_irc = series2diagnosis_dict.get(series_uid, (0.0, None))
+            if center_irc is not None:
+                center_irc = tuple(int(x.item()) for x in center_irc)
+            malignant_bool = series_uid in malignant_set
+            prediction_bool = probability_float > 0.5
+            correct_bool = malignant_bool == prediction_bool
+
+            if malignant_bool and prediction_bool:
+                count_dict['tp'] += 1
+            if not malignant_bool and not prediction_bool:
+                count_dict['tn'] += 1
+            if not malignant_bool and prediction_bool:
+                count_dict['fp'] += 1
+            if malignant_bool and not prediction_bool:
+                count_dict['fn'] += 1
+
+
+            log.info("{} {} Mal:{!r:5} Pred:{!r:5} Correct?:{!r:5} Value:{:.4f} {}".format(
+                mode_str,
+                series_uid,
+                malignant_bool,
+                prediction_bool,
+                correct_bool,
+                probability_float,
+                center_irc,
+            ))
+
+        total_count = sum(count_dict.values())
+        percent_dict = {k: v / (total_count or 1) * 100 for k, v in count_dict.items()}
+
+        precision = percent_dict['p'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fp']) or 1)
+        recall    = percent_dict['r'] = count_dict['tp'] / ((count_dict['tp'] + count_dict['fn']) or 1)
+        percent_dict['f1'] = 2 * (precision * recall) / ((precision + recall) or 1)
+
+        log.info(mode_str + " tp:{tp:.1f}%, tn:{tn:.1f}%, fp:{fp:.1f}%, fn:{fn:.1f}%".format(
+            **percent_dict,
+        ))
+        log.info(mode_str + " precision:{p:.3f}, recall:{r:.3f}, F1:{f1:.3f}".format(
+            **percent_dict,
+        ))
+
+
+
+if __name__ == '__main__':
+    LunaDiagnoseApp().main()

+ 580 - 0
p2ch14/dsets.py

@@ -0,0 +1,580 @@
+import copy
+import csv
+import functools
+import glob
+import math
+import os
+import random
+
+from collections import namedtuple
+
+import SimpleITK as sitk
+
+import numpy as np
+import scipy.ndimage.morphology as morph
+
+import torch
+import torch.cuda
+from torch.utils.data import Dataset
+import torch.nn as nn
+import torch.nn.functional as F
+
+from util.disk import getCache
+from util.util import XyzTuple, xyz2irc
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+raw_cache = getCache('part2ch13_raw')
+
+NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
+MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_nodule_mask, nodule_mask, lung_mask, ben_mask, mal_mask')
+
+@functools.lru_cache(1)
+def getNoduleInfoList(requireDataOnDisk_bool=True):
+    # We construct a set with all series_uids that are present on disk.
+    # This will let us use the data, even if we haven't downloaded all of
+    # the subsets yet.
+    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
+    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+
+    diameter_dict = {}
+    with open('data/part2/luna/annotations.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
+            annotationDiameter_mm = float(row[4])
+
+            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
+
+    noduleInfo_list = []
+    with open('data/part2/luna/candidates.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+
+            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+                continue
+
+            isMalignant_bool = bool(int(row[4]))
+            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
+
+            candidateDiameter_mm = 0.0
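+            # If the candidate center is within a quarter of the annotated
+            # diameter on every axis, treat the two as the same nodule and
+            # adopt the annotated diameter (the for/else only reaches the
+            # else clause when no axis check breaks out of the loop).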
+            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
+                for i in range(3):
+                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
+                    if delta_mm > annotationDiameter_mm / 4:
+                        break
+                else:
+                    candidateDiameter_mm = annotationDiameter_mm
+                    break
+
+            noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+
+    noduleInfo_list.sort(reverse=True)
+    return noduleInfo_list
+
+class Ct(object):
+    def __init__(self, series_uid, buildMasks_bool=True):
+        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
+
+        ct_mhd = sitk.ReadImage(mhd_path)
+        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+
+        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
+        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
+        # This gets rid of negative density stuff used to indicate out-of-FOV
+        ct_a[ct_a < -1000] = -1000
+
+        # This nukes any weird hotspots and clamps bone down
+        ct_a[ct_a > 1000] = 1000
+
+        self.series_uid = series_uid
+        self.hu_a = ct_a
+
+        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
+        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
+        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
+
+        noduleInfo_list = getNoduleInfoList()
+        self.benignInfo_list = [ni_tup
+                                for ni_tup in noduleInfo_list
+                                    if not ni_tup.isMalignant_bool
+                                        and ni_tup.series_uid == self.series_uid]
+        self.benign_mask = self.buildAnnotationMask(self.benignInfo_list)[0]
+        self.benign_indexes = sorted(set(self.benign_mask.nonzero()[0]))
+
+        self.malignantInfo_list = [ni_tup
+                                   for ni_tup in noduleInfo_list
+                                        if ni_tup.isMalignant_bool
+                                            and ni_tup.series_uid == self.series_uid]
+        self.malignant_mask = self.buildAnnotationMask(self.malignantInfo_list)[0]
+        self.malignant_indexes = sorted(set(self.malignant_mask.nonzero()[0]))
+
+    def buildAnnotationMask(self, noduleInfo_list, threshold_hu = -500):
+        boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool)
+
+        for noduleInfo_tup in noduleInfo_list:
+            center_irc = xyz2irc(
+                noduleInfo_tup.center_xyz,
+                self.origin_xyz,
+                self.vxSize_xyz,
+                self.direction_tup,
+            )
+            ci = int(center_irc.index)
+            cr = int(center_irc.row)
+            cc = int(center_irc.col)
+
+            index_radius = 2
+            try:
+                while self.hu_a[ci + index_radius, cr, cc] > threshold_hu and \
+                        self.hu_a[ci - index_radius, cr, cc] > threshold_hu:
+                    index_radius += 1
+            except IndexError:
+                index_radius -= 1
+
+            row_radius = 2
+            try:
+                while self.hu_a[ci, cr + row_radius, cc] > threshold_hu and \
+                        self.hu_a[ci, cr - row_radius, cc] > threshold_hu:
+                    row_radius += 1
+            except IndexError:
+                row_radius -= 1
+
+            col_radius = 2
+            try:
+                while self.hu_a[ci, cr, cc + col_radius] > threshold_hu and \
+                        self.hu_a[ci, cr, cc - col_radius] > threshold_hu:
+                    col_radius += 1
+            except IndexError:
+                col_radius -= 1
+
+            # assert index_radius > 0, repr([noduleInfo_tup.center_xyz, center_irc, self.hu_a[ci, cr, cc]])
+            # assert row_radius > 0
+            # assert col_radius > 0
+
+
+            slice_tup = (
+                slice(ci - index_radius, ci + index_radius + 1),
+                slice(cr - row_radius, cr + row_radius + 1),
+                slice(cc - col_radius, cc + col_radius + 1),
+            )
+            boundingBox_a[slice_tup] = True
+
+        thresholded_a = boundingBox_a & (self.hu_a > threshold_hu)
+        mask_a = morph.binary_dilation(thresholded_a, iterations=2)
+
+        return mask_a, thresholded_a, boundingBox_a
+
+    def build2dLungMask(self, mask_ndx):
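+        # Morphological pipeline: threshold dense tissue, denoise with a
+        # closing then an opening, fill holes to get the body outline, take
+        # body-minus-dense as air, then dilate that into the lung candidate
+        # region searched for nodule-density voxels below.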
+        raw_dense_mask = self.hu_a[mask_ndx] > -300
+        dense_mask = morph.binary_closing(raw_dense_mask, iterations=2)
+        dense_mask = morph.binary_opening(dense_mask, iterations=2)
+
+        body_mask = morph.binary_fill_holes(dense_mask)
+        air_mask = morph.binary_fill_holes(body_mask & ~dense_mask)
+        air_mask = morph.binary_erosion(air_mask, iterations=1)
+
+        lung_mask = morph.binary_dilation(air_mask, iterations=5)
+
+        raw_nodule_mask = self.hu_a[mask_ndx] > -600
+        raw_nodule_mask &= air_mask
+        nodule_mask = morph.binary_opening(raw_nodule_mask, iterations=1)
+
+        ben_mask = morph.binary_dilation(nodule_mask, iterations=1)
+        ben_mask &= ~self.malignant_mask[mask_ndx]
+
+        mal_mask = self.malignant_mask[mask_ndx]
+
+        return MaskTuple(
+            raw_dense_mask,
+            dense_mask,
+            body_mask,
+            air_mask,
+            raw_nodule_mask,
+            nodule_mask,
+            lung_mask,
+            ben_mask,
+            mal_mask,
+        )
+
+    # def build3dLungMask(self):
+    #     air_mask, lung_mask, dense_mask, denoise_mask, body_mask, ben_mask, mal_mask = mask_list = \
+    #         [np.zeros_like(self.hu_a, dtype=np.bool) for _ in range(7)]
+    #
+    #     for mask_ndx in range(self.hu_a.shape[0]):
+    #         for i, mask_a in enumerate(self.build2dLungMask(mask_ndx)):
+    #             mask_list[i][mask_ndx] = mask_a
+    #
+    #     return MaskTuple(air_mask, lung_mask, dense_mask, denoise_mask, body_mask, ben_mask, mal_mask)
+
+
+    def getRawNodule(self, center_xyz, width_irc):
+        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
+
+        slice_list = []
+        for axis, center_val in enumerate(center_irc):
+            try:
+                start_ndx = int(round(center_val - width_irc[axis]/2))
+            except:
+                log.debug([center_val, width_irc, center_xyz, center_irc])
+                raise
+            end_ndx = int(start_ndx + width_irc[axis])
+
+            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+
+            if start_ndx < 0:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
+                start_ndx = 0
+                end_ndx = int(width_irc[axis])
+
+            if end_ndx > self.hu_a.shape[axis]:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
+                end_ndx = self.hu_a.shape[axis]
+                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
+
+            slice_list.append(slice(start_ndx, end_ndx))
+
+        ct_chunk = self.hu_a[tuple(slice_list)]
+
+        return ct_chunk, center_irc
+
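+# Loading a CT and building its masks is expensive, so keep the five most
+# recently used Ct instances alive in an in-process LRU cache.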
+ctCache_depth = 5
+@functools.lru_cache(ctCache_depth, typed=True)
+def getCt(series_uid):
+    return Ct(series_uid)
+
+@raw_cache.memoize(typed=True)
+def getCtRawNodule(series_uid, center_xyz, width_irc):
+    ct = getCt(series_uid)
+    ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+    return ct_chunk, center_irc
+
+@raw_cache.memoize(typed=True)
+def getCtSampleSize(series_uid):
+    ct = Ct(series_uid, buildMasks_bool=False)
+    return len(ct.benign_indexes)
+
+def getCtAugmentedNodule(
+        augmentation_dict,
+        series_uid, center_xyz, width_irc,
+        use_cache=True):
+    if use_cache:
+        ct_chunk, center_irc = getCtRawNodule(series_uid, center_xyz, width_irc)
+    else:
+        ct = getCt(series_uid)
+        ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+
+    ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
+
+    transform_t = torch.eye(4).to(torch.float64)
+
+    for i in range(3):
+        if 'flip' in augmentation_dict:
+            if random.random() > 0.5:
+                transform_t[i,i] *= -1
+
+        if 'offset' in augmentation_dict:
+            offset_float = augmentation_dict['offset']
+            random_float = (random.random() * 2 - 1)
+            transform_t[i,3] = offset_float * random_float
+
+        if 'scale' in augmentation_dict:
+            scale_float = augmentation_dict['scale']
+            random_float = (random.random() * 2 - 1)
+            transform_t[i,i] *= 1.0 + scale_float * random_float
+
+    if 'rotate' in augmentation_dict:
+        angle_rad = random.random() * math.pi * 2
+        s = math.sin(angle_rad)
+        c = math.cos(angle_rad)
+
+        rotation_t = torch.tensor([
+            [c, -s, 0, 0],
+            [s, c, 0, 0],
+            [0, 0, 1, 0],
+            [0, 0, 0, 1],
+        ], dtype=torch.float64)
+
+        transform_t @= rotation_t
+
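+    # affine_grid expects an (N, 3, 4) matrix for 3D input, so only the top
+    # three rows of the 4x4 homogeneous transform are passed; grid_sample
+    # then resamples the chunk through the resulting grid in one pass.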
+    affine_t = F.affine_grid(
+            transform_t[:3].unsqueeze(0).to(torch.float32),
+            ct_t.size(),
+        )
+
+    augmented_chunk = F.grid_sample(
+            ct_t,
+            affine_t,
+            padding_mode='border'
+        ).to('cpu')
+
+    if 'noise' in augmentation_dict:
+        noise_t = torch.randn_like(augmented_chunk)
+        noise_t *= augmentation_dict['noise']
+
+        augmented_chunk += noise_t
+
+    return augmented_chunk[0], center_irc
+
+
+class LunaDataset(Dataset):
+    def __init__(self,
+                 val_stride=0,
+                 isValSet_bool=None,
+                 series_uid=None,
+                 sortby_str='random',
+                 ratio_int=0,
+                 augmentation_dict=None,
+                 noduleInfo_list=None,
+            ):
+        self.ratio_int = ratio_int
+        self.augmentation_dict = augmentation_dict
+
+        if noduleInfo_list:
+            self.noduleInfo_list = copy.copy(noduleInfo_list)
+            self.use_cache = False
+        else:
+            self.noduleInfo_list = copy.copy(getNoduleInfoList())
+            self.use_cache = True
+
+        if series_uid:
+            self.series_list = [series_uid]
+        else:
+            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
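+        # Split train/val at the whole-series level (every val_stride-th
+        # series becomes validation) so candidates from a single CT can
+        # never appear in both sets.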
+        if isValSet_bool:
+            assert val_stride > 0, val_stride
+            self.series_list = self.series_list[::val_stride]
+            assert self.series_list
+        elif val_stride > 0:
+            del self.series_list[::val_stride]
+            assert self.series_list
+
+        series_set = set(self.series_list)
+        self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid in series_set]
+
+        if sortby_str == 'random':
+            random.shuffle(self.noduleInfo_list)
+        elif sortby_str == 'series_uid':
+            self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sorting by series_uid, center_xyz)
+        elif sortby_str == 'malignancy_size':
+            pass
+        else:
+            raise Exception("Unknown sort: " + repr(sortby_str))
+
+        self.benign_list = [nt for nt in self.noduleInfo_list if not nt.isMalignant_bool]
+        self.malignant_list = [nt for nt in self.noduleInfo_list if nt.isMalignant_bool]
+
+        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
+            self,
+            len(self.noduleInfo_list),
+            "validation" if isValSet_bool else "training",
+            len(self.benign_list),
+            len(self.malignant_list),
+            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
+        ))
+
+    def shuffleSamples(self):
+        if self.ratio_int:
+            random.shuffle(self.benign_list)
+            random.shuffle(self.malignant_list)
+
+    def __len__(self):
+        if self.ratio_int:
+            # return 20000
+            return 200000
+        else:
+            return len(self.noduleInfo_list)
+
+    def __getitem__(self, ndx):
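+        # With ratio_int set, samples are interleaved so that every
+        # (ratio_int + 1)-th item is malignant and the rest are benign,
+        # wrapping around each sub-list as needed.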
+        if self.ratio_int:
+            malignant_ndx = ndx // (self.ratio_int + 1)
+
+            if ndx % (self.ratio_int + 1):
+                benign_ndx = ndx - 1 - malignant_ndx
+                nodule_tup = self.benign_list[benign_ndx % len(self.benign_list)]
+            else:
+                nodule_tup = self.malignant_list[malignant_ndx % len(self.malignant_list)]
+        else:
+            nodule_tup = self.noduleInfo_list[ndx]
+
+        width_irc = (32, 48, 48)
+
+        if self.augmentation_dict:
+            nodule_t, center_irc = getCtAugmentedNodule(
+                self.augmentation_dict,
+                nodule_tup.series_uid,
+                nodule_tup.center_xyz,
+                width_irc,
+                self.use_cache,
+            )
+        elif self.use_cache:
+            nodule_a, center_irc = getCtRawNodule(
+                nodule_tup.series_uid,
+                nodule_tup.center_xyz,
+                width_irc,
+            )
+            nodule_t = torch.from_numpy(nodule_a).to(torch.float32)
+            nodule_t = nodule_t.unsqueeze(0)
+        else:
+            ct = getCt(nodule_tup.series_uid)
+            nodule_a, center_irc = ct.getRawNodule(
+                nodule_tup.center_xyz,
+                width_irc,
+            )
+            nodule_t = torch.from_numpy(nodule_a).to(torch.float32)
+            nodule_t = nodule_t.unsqueeze(0)
+
+        malignant_t = torch.tensor([
+                not nodule_tup.isMalignant_bool,
+                nodule_tup.isMalignant_bool
+            ],
+            dtype=torch.long,
+        )
+
+        # log.debug([type(center_irc), center_irc])
+
+        return nodule_t, malignant_t, nodule_tup.series_uid, torch.tensor(center_irc)
+
+
+class PrepcacheLunaDataset(LunaDataset):
+    def __getitem__(self, ndx):
+        nodule_t, malignant_t, series_uid, center_t = super().__getitem__(ndx)
+        getCtSampleSize(series_uid)
+        return nodule_t, malignant_t, series_uid, center_t
+
+
+class Luna2dSegmentationDataset(Dataset):
+    def __init__(self,
+                 val_stride=0,
+                 isValSet_bool=None,
+                 series_uid=None,
+                 contextSlices_count=2,
+                 augmentation_dict=None,
+                 fullCt_bool=False,
+            ):
+        self.contextSlices_count = contextSlices_count
+        self.augmentation_dict = augmentation_dict
+
+        if series_uid:
+            self.series_list = [series_uid]
+        else:
+            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
+        if isValSet_bool:
+            assert val_stride > 0, val_stride
+            self.series_list = self.series_list[::val_stride]
+            assert self.series_list
+        elif val_stride > 0:
+            del self.series_list[::val_stride]
+            assert self.series_list
+
+        self.sample_list = []
+        for series_uid in self.series_list:
+            if fullCt_bool:
+                self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCt(series_uid).hu_a.shape[0])])
+            else:
+                self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCtSampleSize(series_uid))])
+
+        log.info("{!r}: {} {} series, {} slices".format(
+            self,
+            len(self.series_list),
+            {None: 'general', True: 'validation', False: 'training'}[isValSet_bool],
+            len(self.sample_list),
+        ))
+
+    def __len__(self):
+        return len(self.sample_list) #// 100
+
+    def __getitem__(self, ndx):
+        if isinstance(ndx, int):
+            series_uid, ct_ndx = self.sample_list[ndx % len(self.sample_list)]
+            ct = getCt(series_uid)
+            useAugmentation_bool = False
+        else:
+            series_uid, ct_ndx, useAugmentation_bool = ndx
+            ct = getCt(series_uid)
+
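+        # One channel per context slice plus the center slice, plus a final
+        # channel holding the 2D lung mask filled in below.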
+        ct_t = torch.zeros((self.contextSlices_count * 2 + 1 + 1, 512, 512))
+
+        start_ndx = ct_ndx - self.contextSlices_count
+        end_ndx = ct_ndx + self.contextSlices_count + 1
+        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
+            context_ndx = max(context_ndx, 0)
+            context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
+
+            ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
+        ct_t /= 1000
+
+        mask_tup = ct.build2dLungMask(ct_ndx)
+
+        ct_t[-1] = torch.from_numpy(mask_tup.lung_mask.astype(np.float32))
+
+        nodule_t = torch.from_numpy(
+            (mask_tup.mal_mask | mask_tup.ben_mask).astype(np.float32)
+        ).unsqueeze(0)
+        ben_t = torch.from_numpy(mask_tup.ben_mask.astype(np.float32)).unsqueeze(0)
+        mal_t = torch.from_numpy(mask_tup.mal_mask.astype(np.float32)).unsqueeze(0)
+        label_int = mal_t.max() + ben_t.max() * 2
+
+        if self.augmentation_dict and useAugmentation_bool:
+            if 'rotate' in self.augmentation_dict:
+                if random.random() > 0.5:
+                    ct_t = ct_t.rot90(1, [1, 2])
+                    nodule_t = nodule_t.rot90(1, [1, 2])
+
+            if 'flip' in self.augmentation_dict:
+                dims = [d+1 for d in range(2) if random.random() > 0.5]
+
+                if dims:
+                    ct_t = ct_t.flip(dims)
+                    nodule_t = nodule_t.flip(dims)
+
+            if 'noise' in self.augmentation_dict:
+                noise_t = torch.randn_like(ct_t)
+                noise_t *= self.augmentation_dict['noise']
+
+                ct_t += noise_t
+        return ct_t, nodule_t, label_int, ben_t, mal_t, ct.series_uid, ct_ndx
+
+
+class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
+    def __init__(self, *args, batch_size=80, **kwargs):
+        self.needsShuffle_bool = True
+        self.batch_size = batch_size
+        # self.rotate_frac = 0.5 * len(self.series_list) / len(self)
+        super().__init__(*args, **kwargs)
+
+    def __len__(self):
+        return 50000
+
+    def __getitem__(self, ndx):
+        if self.needsShuffle_bool:
+            random.shuffle(self.series_list)
+            self.needsShuffle_bool = False
+
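+        # Rotate the series list once per batch so consecutive samples hit
+        # the ctCache_depth-deep getCt cache, and bias the sampling: two out
+        # of every three slices are centered on a known nodule.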
+        if isinstance(ndx, int):
+            if ndx % self.batch_size == 0:
+                self.series_list.append(self.series_list.pop(0))
+
+            series_uid = self.series_list[ndx % ctCache_depth]
+            ct = getCt(series_uid)
+
+            if ndx % 3 == 0:
+                ct_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
+            elif ndx % 3 == 1:
+                ct_ndx = random.choice(ct.benign_indexes)
+            elif ndx % 3 == 2:
+                ct_ndx = random.choice(list(range(ct.hu_a.shape[0])))
+
+            useAugmentation_bool = True
+        else:
+            series_uid, ct_ndx, useAugmentation_bool = ndx
+
+        return super().__getitem__((series_uid, ct_ndx, useAugmentation_bool))

+ 92 - 0
p2ch14/model_cls.py

@@ -0,0 +1,92 @@
+import math
+
+import torch.nn as nn
+
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+
+class LunaModel(nn.Module):
+    def __init__(self, in_channels=1, conv_channels=8):
+        super().__init__()
+
+        self.tail_batchnorm = nn.BatchNorm3d(1)
+
+        self.block1 = LunaBlock(in_channels, conv_channels)
+        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
+        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
+        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)
+
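+        # After four 2x max-pools the 32x48x48 input is 2x3x3 with
+        # conv_channels * 8 = 64 channels: 64 * 2 * 3 * 3 = 1152 features.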
+        self.head_linear = nn.Linear(1152, 2)
+        self.head_softmax = nn.Softmax(dim=1)
+
+        self._init_weights()
+
+    # see also https://github.com/pytorch/pytorch/issues/18182
+    def _init_weights(self):
+        for m in self.modules():
+            if type(m) in {
+                nn.Linear,
+                nn.Conv3d,
+                nn.Conv2d,
+                nn.ConvTranspose2d,
+                nn.ConvTranspose3d,
+            }:
+                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
+
+
+    def forward(self, input_batch):
+        bn_output = self.tail_batchnorm(input_batch)
+
+        block_out = self.block1(bn_output)
+        block_out = self.block2(block_out)
+        block_out = self.block3(block_out)
+        block_out = self.block4(block_out)
+
+        conv_flat = block_out.view(
+            block_out.size(0),
+            -1,
+        )
+        linear_output = self.head_linear(conv_flat)
+
+        return linear_output, self.head_softmax(linear_output)
+
+
+class LunaBlock(nn.Module):
+    def __init__(self, in_channels, conv_channels):
+        super().__init__()
+
+        self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.maxpool = nn.MaxPool3d(2, 2)
+
+    def forward(self, input_batch):
+        block_out = self.conv1(input_batch)
+        block_out = self.relu1(block_out)
+        block_out = self.conv2(block_out)
+        block_out = self.relu2(block_out)
+
+        return self.maxpool(block_out)
+
+
+class AlternateLunaModel(LunaModel):
+    def __init__(self, in_channels=1, conv_channels=64):
+        super().__init__()
+
+        self.block1 = LunaBlock(in_channels, conv_channels)
+        self.block2 = LunaBlock(conv_channels, conv_channels // 2)
+        self.block3 = LunaBlock(conv_channels // 2, conv_channels // 4)
+        self.block4 = LunaBlock(conv_channels // 4, conv_channels // 8)
+
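+        # Channels shrink to conv_channels // 8 = 8 here, so the flattened
+        # feature count is 8 * 2 * 3 * 3 = 144.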
+        self.head_linear = nn.Linear(144, 2)

+ 46 - 0
p2ch14/model_seg.py

@@ -0,0 +1,46 @@
+import math
+
+from torch import nn as nn
+
+from util.logconf import logging
+from util.unet import UNet
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+class UNetWrapper(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.input_batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
+        self.unet = UNet(**kwargs)
+        self.final = nn.Sigmoid()
+
+        self._init_weights()
+
+    def _init_weights(self):
+        init_set = {
+            nn.Conv2d,
+            nn.Conv3d,
+            nn.ConvTranspose2d,
+            nn.ConvTranspose3d,
+            nn.Linear,
+        }
+        for m in self.modules():
+            if type(m) in init_set:
+                nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu', a=0)
+                if m.bias is not None:
+                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    bound = 1 / math.sqrt(fan_out)
+                    nn.init.normal_(m.bias, -bound, bound)
+
+
+    def forward(self, input_batch):
+        bn_output = self.input_batchnorm(input_batch)
+        un_output = self.unet(bn_output)
+        fn_output = self.final(un_output)
+
+        return fn_output
+

+ 68 - 0
p2ch14/prepcache.py

@@ -0,0 +1,68 @@
+import argparse
+import sys
+
+import numpy as np
+
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import PrepcacheLunaDataset, getCtSampleSize
+from util.logconf import logging
+# from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+
+class LunaPrepCacheApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=1024,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        # parser.add_argument('--scaled',
+        #     help="Scale the CT chunks to square voxels.",
+        #     default=False,
+        #     action='store_true',
+        # )
+
+        self.cli_args = parser.parse_args(sys_argv)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.prep_dl = DataLoader(
+            PrepcacheLunaDataset(
+                sortby_str='series_uid',
+            ),
+            batch_size=self.cli_args.batch_size,
+            num_workers=self.cli_args.num_workers,
+        )
+
+        batch_iter = enumerateWithEstimate(
+            self.prep_dl,
+            "Stuffing cache",
+            start_ndx=self.prep_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            pass
+
+
+if __name__ == '__main__':
+    LunaPrepCacheApp().main()

+ 92 - 0
p2ch14/screencts.py

@@ -0,0 +1,92 @@
+import argparse
+import sys
+
+import numpy as np
+
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import Dataset, DataLoader
+
+from util.util import enumerateWithEstimate, prhist
+from .dsets import getNoduleInfoList, getCt
+from util.logconf import logging
+# from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+
+class LunaScreenCtDataset(Dataset):
+    def __init__(self):
+        self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
+
+    def __len__(self):
+        return len(self.series_list)
+
+    def __getitem__(self, ndx):
+        series_uid = self.series_list[ndx]
+        ct = getCt(series_uid)
+        mid_ndx = ct.hu_a.shape[0] // 2
+
+        # build2dLungMask returns a 9-field MaskTuple; unpacking it into the
+        # stale seven names used previously would raise a ValueError, so use
+        # the named fields instead.
+        mask_tup = ct.build2dLungMask(mid_ndx)
+
+        # Assumption: the dense-tissue / body-area ratio stands in for the
+        # old dense/denoise screening ratio, which no longer exists.
+        return series_uid, float(mask_tup.dense_mask.sum() / mask_tup.body_mask.sum())
+
+
+class LunaScreenCtApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=4,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        # parser.add_argument('--scaled',
+        #     help="Scale the CT chunks to square voxels.",
+        #     default=False,
+        #     action='store_true',
+        # )
+
+        self.cli_args = parser.parse_args(sys_argv)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.prep_dl = DataLoader(
+            LunaScreenCtDataset(),
+            batch_size=self.cli_args.batch_size,
+            num_workers=self.cli_args.num_workers,
+        )
+
+        series2ratio_dict = {}
+
+        batch_iter = enumerateWithEstimate(
+            self.prep_dl,
+            "Screening CTs",
+            start_ndx=self.prep_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            series_list, ratio_list = batch_tup
+            for series_uid, ratio_float in zip(series_list, ratio_list):
+                series2ratio_dict[series_uid] = ratio_float
+            # break
+
+        prhist(list(series2ratio_dict.values()))
+
+
+
+
+if __name__ == '__main__':
+    LunaScreenCtApp().main()

+ 462 - 0
p2ch14/train_cls.py

@@ -0,0 +1,462 @@
+import argparse
+import datetime
+import os
+import sys
+
+import numpy as np
+
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+from torch.optim import SGD, Adam
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset
+from .model_cls import LunaModel
+
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
+METRICS_LABEL_NDX=0
+METRICS_PRED_NDX=1
+METRICS_LOSS_NDX=2
+METRICS_SIZE = 3
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=32,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+        parser.add_argument('--balanced',
+            help="Balance the training data to half benign, half malignant.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augmented',
+            help="Augment the training data.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-flip',
+            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-offset',
+            help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-scale',
+            help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-rotate',
+            help="Augment the training data by randomly rotating the data around the head-foot axis.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-noise',
+            help="Augment the training data by randomly adding noise to the data.",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch14',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='none',
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
+
+        self.totalTrainingSamples_count = 0
+        self.trn_writer = None
+        self.val_writer = None
+
+        self.augmentation_dict = {}
+        if self.cli_args.augmented or self.cli_args.augment_flip:
+            self.augmentation_dict['flip'] = True
+        if self.cli_args.augmented or self.cli_args.augment_offset:
+            self.augmentation_dict['offset'] = 0.1
+        if self.cli_args.augmented or self.cli_args.augment_scale:
+            self.augmentation_dict['scale'] = 0.2
+        if self.cli_args.augmented or self.cli_args.augment_rotate:
+            self.augmentation_dict['rotate'] = True
+        if self.cli_args.augmented or self.cli_args.augment_noise:
+            self.augmentation_dict['noise'] = 25.0
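+        # offset and scale are expressed in the normalized [-1, 1] grid
+        # coordinates used by affine_grid; noise is a standard deviation in
+        # raw Hounsfield units, added after resampling.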
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        self.model = self.initModel()
+        self.optimizer = self.initOptimizer()
+
+
+    def initModel(self):
+        model = LunaModel()
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                model = nn.DataParallel(model)
+            model = model.to(self.device)
+        return model
+
+    def initOptimizer(self):
+        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
+#         return Adam(self.model.parameters())
+
+    def initTrainDl(self):
+        train_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=False,
+            ratio_int=int(self.cli_args.balanced),
+            augmentation_dict=self.augmentation_dict,
+        )
+
+        train_dl = DataLoader(
+            train_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return train_dl
+
+    def initValDl(self):
+        val_ds = LunaDataset(
+            val_stride=10,
+            isValSet_bool=True,
+        )
+
+        val_dl = DataLoader(
+            val_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return val_dl
+
+    def initTensorboardWriters(self):
+        if self.trn_writer is None:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_cls_' + self.cli_args.comment)
+            self.val_writer = SummaryWriter(log_dir=log_dir + '_val_cls_' + self.cli_args.comment)
+
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+        val_dl = self.initValDl()
+
+        best_score = 0.0
+
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(val_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
+
+            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
+            score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
+            best_score = max(score, best_score)
+
+            self.saveModel('cls', epoch_ndx, score == best_score)
+
+        if self.trn_writer is not None:
+            self.trn_writer.close()
+            self.val_writer.close()
+
+
+    def doTraining(self, epoch_ndx, train_dl):
+        self.model.train()
+        train_dl.dataset.shuffleSamples()
+        trnMetrics_g = torch.zeros(
+                METRICS_SIZE,
+                len(train_dl.dataset),
+            ).to(self.device)
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            loss_var = self.computeBatchLoss(
+                batch_ndx,
+                batch_tup,
+                train_dl.batch_size,
+                trnMetrics_g
+            )
+
+            loss_var.backward()
+            self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trnMetrics_g.size(1)
+
+        return trnMetrics_g.to('cpu')
+
+
+    def doValidation(self, epoch_ndx, val_dl):
+        with torch.no_grad():
+            self.model.eval()
+            valMetrics_g = torch.zeros(
+                    METRICS_SIZE,
+                    len(val_dl.dataset),
+                ).to(self.device)
+            batch_iter = enumerateWithEstimate(
+                val_dl,
+                "E{} Validation ".format(epoch_ndx),
+                start_ndx=val_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                self.computeBatchLoss(
+                    batch_ndx,
+                    batch_tup,
+                    val_dl.batch_size,
+                    valMetrics_g,
+                )
+
+        return valMetrics_g.to('cpu')
+
+
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
+        input_t, label_t, _series_list, _center_list = batch_tup
+
+        input_g = input_t.to(self.device, non_blocking=True)
+        label_g = label_t.to(self.device, non_blocking=True)
+
+        logits_g, probability_g = self.model(input_g)
+
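+        # reduction='none' keeps one loss value per sample so each can be
+        # recorded in metrics_g below; only the mean is used for backprop.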
+        loss_func = nn.CrossEntropyLoss(reduction='none')
+        loss_g = loss_func(logits_g, label_g[:,1])
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_t.size(0)
+
+        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_g[:,1]
+        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = probability_g[:,1]
+        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_g
+
+        return loss_g.mean()
+
+
+    def logMetrics(
+            self,
+            epoch_ndx,
+            mode_str,
+            metrics_t,
+    ):
+        self.initTensorboardWriters()
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+        metrics_a = metrics_t.cpu().detach().numpy()
+#         assert np.isfinite(metrics_a).all()
+
+        benLabel_mask = metrics_a[METRICS_LABEL_NDX] <= 0.5
+        benPred_mask = metrics_a[METRICS_PRED_NDX] <= 0.5
+
+        malLabel_mask = ~benLabel_mask
+        malPred_mask = ~benPred_mask
+
+        benLabel_count = benLabel_mask.sum()
+        malLabel_count = malLabel_mask.sum()
+
+        trueNeg_count = benCorrect_count = (benLabel_mask & benPred_mask).sum()
+        truePos_count = malCorrect_count = (malLabel_mask & malPred_mask).sum()
+
+        falsePos_count = benLabel_count - benCorrect_count
+        falseNeg_count = malLabel_count - malCorrect_count
+
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_a[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/ben'] = metrics_a[METRICS_LOSS_NDX, benLabel_mask].mean()
+        metrics_dict['loss/mal'] = metrics_a[METRICS_LOSS_NDX, malLabel_mask].mean()
+
+        metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_a.shape[1] * 100
+        metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
+        metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
+
+        precision = metrics_dict['pr/precision'] = \
+            truePos_count / (truePos_count + falsePos_count)
+        recall = metrics_dict['pr/recall'] = \
+            truePos_count / (truePos_count + falseNeg_count)
+
+        metrics_dict['pr/f1_score'] = 2 * (precision * recall)\
+            / (precision + recall)
+
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/all:.4f} loss, "
+                 + "{correct/all:-5.1f}% correct, "
+                 + "{pr/precision:.4f} precision, "
+                 + "{pr/recall:.4f} recall, "
+                 + "{pr/f1_score:.4f} f1 score"
+            ).format(
+                epoch_ndx,
+                mode_str,
+                **metrics_dict,
+            )
+        )
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/ben:.4f} loss, "
+                 + "{correct/ben:-5.1f}% correct "
+                 + "({benCorrect_count:} of {benLabel_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_ben',
+                benCorrect_count=benCorrect_count,
+                benLabel_count=benLabel_count,
+                **metrics_dict,
+            )
+        )
+        log.info(
+            ("E{} {:8} "
+                 + "{loss/mal:.4f} loss, "
+                 + "{correct/mal:-5.1f}% correct "
+                 + "({malCorrect_count:} of {malLabel_count:})"
+            ).format(
+                epoch_ndx,
+                mode_str + '_mal',
+                malCorrect_count=malCorrect_count,
+                malLabel_count=malLabel_count,
+                **metrics_dict,
+            )
+        )
+        writer = getattr(self, mode_str + '_writer')
+
+        for key, value in metrics_dict.items():
+            writer.add_scalar(key, value, self.totalTrainingSamples_count)
+
+        writer.add_pr_curve(
+            'pr',
+            metrics_a[METRICS_LABEL_NDX],
+            metrics_a[METRICS_PRED_NDX],
+            self.totalTrainingSamples_count,
+        )
+
+        bins = [x/50.0 for x in range(51)]
+
+        benHist_mask = benLabel_mask & (metrics_a[METRICS_PRED_NDX] > 0.01)
+        malHist_mask = malLabel_mask & (metrics_a[METRICS_PRED_NDX] < 0.99)
+
+        if benHist_mask.any():
+            writer.add_histogram(
+                'is_ben',
+                metrics_a[METRICS_PRED_NDX, benHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+        if malHist_mask.any():
+            writer.add_histogram(
+                'is_mal',
+                metrics_a[METRICS_PRED_NDX, malHist_mask],
+                self.totalTrainingSamples_count,
+                bins=bins,
+            )
+
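+        # Heuristic checkpoint score: F1 dominates, with small penalties on
+        # the malignant and overall losses to break ties between epochs.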
+        score = 1 \
+            + metrics_dict['pr/f1_score'] \
+            - metrics_dict['loss/mal'] * 0.01 \
+            - metrics_dict['loss/all'] * 0.0001
+
+        return score
+
+    def saveModel(self, type_str, epoch_ndx, isBest=False):
+        file_path = os.path.join(
+            'data-unversioned',
+            'part2',
+            'models',
+            self.cli_args.tb_prefix,
+            '{}_{}_{}.{}.state'.format(
+                type_str,
+                self.time_str,
+                self.cli_args.comment,
+                self.totalTrainingSamples_count,
+            )
+        )
+
+        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
+
+        model = self.model
+        if hasattr(model, 'module'):
+            model = model.module
+
+        state = {
+            'model_state': model.state_dict(),
+            'model_name': type(model).__name__,
+            'optimizer_state' : self.optimizer.state_dict(),
+            'optimizer_name': type(self.optimizer).__name__,
+            'epoch': epoch_ndx,
+            'totalTrainingSamples_count': self.totalTrainingSamples_count,
+            # 'resumed_from': self.cli_args.resume,
+        }
+        torch.save(state, file_path)
+
+        log.debug("Saved model params to {}".format(file_path))
+
+        if isBest:
+            file_path = os.path.join(
+                'data-unversioned',
+                'part2',
+                'models',
+                self.cli_args.tb_prefix,
+                '{}_{}_{}.{}.state'.format(
+                    type_str,
+                    self.time_str,
+                    self.cli_args.comment,
+                    'best',
+                )
+            )
+            torch.save(state, file_path)
+
+            log.debug("Saved model params to {}".format(file_path))
+
+if __name__ == '__main__':
+    sys.exit(LunaTrainingApp().main() or 0)

+ 580 - 0
p2ch14/train_seg.py

@@ -0,0 +1,580 @@
+import argparse
+import datetime
+import os
+import socket
+import sys
+
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+import torch.optim
+
+from torch.optim import SGD, Adam
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import Luna2dSegmentationDataset, TrainingLuna2dSegmentationDataset, getCt
+from util.logconf import logging
+from util.util import xyz2irc
+from .model_seg import UNetWrapper
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
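+# Typical invocation (illustrative; the flags are defined in __init__ below):
+#   python -m p2ch14.train_seg --epochs 10 --augmented my-comment
+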
+# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
+METRICS_LABEL_NDX = 0
+METRICS_LOSS_NDX = 1
+METRICS_MAL_LOSS_NDX = 2
+# METRICS_ALL_LOSS_NDX = 3
+
+METRICS_MTP_NDX = 4
+METRICS_MFN_NDX = 5
+METRICS_MFP_NDX = 6
+METRICS_ATP_NDX = 7
+METRICS_AFN_NDX = 8
+METRICS_AFP_NDX = 9
+
+METRICS_SIZE = 10
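+# e.g. doTraining/doValidation below allocate
+# torch.zeros(METRICS_SIZE, len(dl.dataset)): one row per statistic above,
+# one column per sample.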
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=16,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+
+        parser.add_argument('--augmented',
+            help="Augment the training data.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-flip',
+            help="Augment the training data by randomly flipping the data left-right, up-down, and front-back.",
+            action='store_true',
+            default=False,
+        )
+        # parser.add_argument('--augment-offset',
+        #     help="Augment the training data by randomly offsetting the data slightly along the X and Y axes.",
+        #     action='store_true',
+        #     default=False,
+        # )
+        # parser.add_argument('--augment-scale',
+        #     help="Augment the training data by randomly increasing or decreasing the size of the nodule.",
+        #     action='store_true',
+        #     default=False,
+        # )
+        parser.add_argument('--augment-rotate',
+            help="Augment the training data by randomly rotating the data around the head-foot axis.",
+            action='store_true',
+            default=False,
+        )
+        parser.add_argument('--augment-noise',
+            help="Augment the training data by randomly adding noise to the data.",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch14',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='none',
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
+        self.totalTrainingSamples_count = 0
+        self.trn_writer = None
+        self.val_writer = None
+
+        augmentation_dict = {}
+        if self.cli_args.augmented or self.cli_args.augment_flip:
+            augmentation_dict['flip'] = True
+        if self.cli_args.augmented or self.cli_args.augment_rotate:
+            augmentation_dict['rotate'] = True
+        if self.cli_args.augmented or self.cli_args.augment_noise:
+            augmentation_dict['noise'] = 0.025
+        self.augmentation_dict = augmentation_dict
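+        # e.g. passing --augmented alone is equivalent to passing all three
+        # individual flags, yielding {'flip': True, 'rotate': True,
+        # 'noise': 0.025}.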
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        self.model = self.initModel()
+        self.optimizer = self.initOptimizer()
+
+
+    def initModel(self):
+        model = UNetWrapper(
+            in_channels=8,
+            n_classes=1,
+            depth=4,
+            wf=3,
+            padding=True,
+            batch_norm=True,
+            up_mode='upconv',
+        )
+
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                model = nn.DataParallel(model)
+            model = model.to(self.device)
+        return model
+
+    def initOptimizer(self):
+        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)
+        # return Adam(self.model.parameters())
+
+
+    def initTrainDl(self):
+        train_ds = TrainingLuna2dSegmentationDataset(
+            val_stride=10,
+            isValSet_bool=False,
+            contextSlices_count=3,
+            augmentation_dict=self.augmentation_dict,
+        )
+
+        train_dl = DataLoader(
+            train_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return train_dl
+
+    def initValDl(self):
+        val_ds = Luna2dSegmentationDataset(
+            val_stride=10,
+            isValSet_bool=True,
+            contextSlices_count=3,
+        )
+
+        val_dl = DataLoader(
+            val_ds,
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        return val_dl
+
+    def initTensorboardWriters(self):
+        if self.trn_writer is None:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_seg_' + self.cli_args.comment)
+            self.val_writer = SummaryWriter(log_dir=log_dir + '_val_seg_' + self.cli_args.comment)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        train_dl = self.initTrainDl()
+        val_dl = self.initValDl()
+
+        # self.logModelMetrics(self.model)
+
+        best_score = 0.0
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(val_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
+            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
+            self.logImages(epoch_ndx, 'trn', train_dl)
+            self.logImages(epoch_ndx, 'val', val_dl)
+            # self.logModelMetrics(self.model)
+
+            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
+            score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
+            best_score = max(score, best_score)
+
+            self.saveModel('seg', epoch_ndx, score == best_score)
+
+        if self.trn_writer is not None:
+            self.trn_writer.close()
+            self.val_writer.close()
+
+    def doTraining(self, epoch_ndx, train_dl):
+        trnMetrics_g = torch.zeros(METRICS_SIZE, len(train_dl.dataset)).to(self.device)
+        self.model.train()
+
+        batch_iter = enumerateWithEstimate(
+            train_dl,
+            "E{} Training".format(epoch_ndx),
+            start_ndx=train_dl.num_workers,
+        )
+        for batch_ndx, batch_tup in batch_iter:
+            self.optimizer.zero_grad()
+
+            loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trnMetrics_g)
+            loss_var.backward()
+
+            self.optimizer.step()
+            del loss_var
+
+        self.totalTrainingSamples_count += trnMetrics_g.size(1)
+
+        return trnMetrics_g.to('cpu')
+
+    def doValidation(self, epoch_ndx, val_dl):
+        with torch.no_grad():
+            valMetrics_g = torch.zeros(METRICS_SIZE, len(val_dl.dataset)).to(self.device)
+            self.model.eval()
+
+            batch_iter = enumerateWithEstimate(
+                val_dl,
+                "E{} Validation ".format(epoch_ndx),
+                start_ndx=val_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                self.computeBatchLoss(batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)
+
+        return valMetrics_g.to('cpu')
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
+        input_t, label_t, label_list, ben_t, mal_t, _, _ = batch_tup
+
+        input_g = input_t.to(self.device, non_blocking=True)
+        label_g = label_t.to(self.device, non_blocking=True)
+        mal_g = mal_t.to(self.device, non_blocking=True)
+        ben_g = ben_t.to(self.device, non_blocking=True)
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_t.size(0)
+        intersectionSum = lambda a, b: (a * b).view(a.size(0), -1).sum(dim=1)
+
+        prediction_g = self.model(input_g)
+        diceLoss_g = self.diceLoss(label_g, prediction_g)
+
+        with torch.no_grad():
+            malLoss_g = self.diceLoss(mal_g, prediction_g * mal_g, p=True)
+            predictionBool_g = (prediction_g > 0.5).to(torch.float32)
+
+            metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
+            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
+            metrics_g[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_g
+
+            malPred_g = predictionBool_g * mal_g
+            tp = intersectionSum(    mal_g,       malPred_g)
+            fn = intersectionSum(    mal_g,   1 - malPred_g)
+
+            metrics_g[METRICS_MTP_NDX, start_ndx:end_ndx] = tp
+            metrics_g[METRICS_MFN_NDX, start_ndx:end_ndx] = fn
+
+            del malPred_g, tp, fn
+
+            tp = intersectionSum(    label_g,     predictionBool_g)
+            fn = intersectionSum(    label_g, 1 - predictionBool_g)
+            fp = intersectionSum(1 - label_g,     predictionBool_g)
+
+            metrics_g[METRICS_ATP_NDX, start_ndx:end_ndx] = tp
+            metrics_g[METRICS_AFN_NDX, start_ndx:end_ndx] = fn
+            metrics_g[METRICS_AFP_NDX, start_ndx:end_ndx] = fp
+
+            del tp, fn, fp
+
+        return diceLoss_g.mean()
+
+    # def diceLoss(self, label_g, prediction_g, epsilon=0.01, p=False):
+    def diceLoss(self, label_g, prediction_g, epsilon=1, p=False):
+        sum_dim1 = lambda t: t.view(t.size(0), -1).sum(dim=1)
+
+        diceLabel_g = sum_dim1(label_g)
+        dicePrediction_g = sum_dim1(prediction_g)
+        diceCorrect_g = sum_dim1(prediction_g * label_g)
+
+        epsilon_g = torch.ones_like(diceCorrect_g) * epsilon
+        diceLoss_g = 1 - (2 * diceCorrect_g + epsilon_g) \
+            / (dicePrediction_g + diceLabel_g + epsilon_g)
+
+        if p and diceLoss_g.mean() < 0:
+            correct_tmp = prediction_g * label_g
+
+            log.debug([])
+            log.debug(['diceCorrect_g   ', diceCorrect_g[0].item(), correct_tmp[0].min().item(), correct_tmp[0].mean().item(), correct_tmp[0].max().item(), correct_tmp.shape])
+            log.debug(['dicePrediction_g', dicePrediction_g[0].item(), prediction_g[0].min().item(), prediction_g[0].mean().item(), prediction_g[0].max().item(), prediction_g.shape])
+            log.debug(['diceLabel_g     ', diceLabel_g[0].item(), label_g[0].min().item(), label_g[0].mean().item(), label_g[0].max().item(), label_g.shape])
+            log.debug(['2*diceCorrect_g ', 2 * diceCorrect_g[0].item()])
+            log.debug(['Prediction + Label      ', dicePrediction_g[0].item() + diceLabel_g[0].item()])
+            log.debug(['diceLoss_g      ', diceLoss_g[0].item()])
+            assert False
+
+        return diceLoss_g
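+
+    # Quick sanity check of the soft Dice loss above (illustrative only,
+    # nothing calls it; `app` is an instance of this class): a perfect
+    # prediction scores ~0, while an empty prediction against a 16-pixel
+    # label scores 1 - (0 + 1)/(0 + 16 + 1) ~= 0.94 with epsilon=1.
+    #
+    #   label_g = torch.zeros(1, 1, 8, 8)
+    #   label_g[0, 0, 2:6, 2:6] = 1
+    #   app.diceLoss(label_g, label_g)                    # tensor([0.])
+    #   app.diceLoss(label_g, torch.zeros_like(label_g))  # tensor([0.9412])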
+
+
+    def logImages(self, epoch_ndx, mode_str, dl):
+        images_iter = sorted(dl.dataset.series_list)[:12]
+        for series_ndx, series_uid in enumerate(images_iter):
+            ct = getCt(series_uid)
+
+            for slice_ndx in range(6):
+                ct_ndx = slice_ndx * ct.hu_a.shape[0] // 5
+                ct_ndx = min(ct_ndx, ct.hu_a.shape[0] - 1)
+                sample_tup = dl.dataset[(series_uid, ct_ndx, False)]
+
+                ct_t, nodule_t, _, ben_t, mal_t, _, _ = sample_tup
+
+                ct_t[:-1,:,:] += 1
+                ct_t[:-1,:,:] /= 2
+
+                input_g = ct_t.to(self.device)
+                label_g = nodule_t.to(self.device)
+
+                prediction_g = self.model(input_g.unsqueeze(0))[0]
+                prediction_a = prediction_g.to('cpu').detach().numpy()
+                label_a = nodule_t.numpy()
+                ben_a = ben_t.numpy()
+                mal_a = mal_t.numpy()
+
+                ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()
+
+                image_a = np.zeros((512, 512, 3), dtype=np.float32)
+                image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
+                image_a[:,:,0] += prediction_a[0] * (1 - label_a[0])
+                image_a[:,:,1] += prediction_a[0] * mal_a[0]
+                image_a[:,:,2] += prediction_a[0] * ben_a[0]
+                image_a *= 0.5
+                image_a[image_a < 0] = 0
+                image_a[image_a > 1] = 1
+
+                writer = getattr(self, mode_str + '_writer')
+                writer.add_image(
+                    '{}/{}_prediction_{}'.format(
+                        mode_str,
+                        series_ndx,
+                        slice_ndx,
+                    ),
+                    image_a,
+                    self.totalTrainingSamples_count,
+                    dataformats='HWC',
+                )
+
+                # self.diceLoss(label_g, prediction_g, p=True)
+
+                if epoch_ndx == 1:
+                    image_a = np.zeros((512, 512, 3), dtype=np.float32)
+                    image_a[:,:,:] = ctSlice_a.reshape((512,512,1))
+                    image_a[:,:,0] += (1 - label_a[0]) * ct_t[-1].numpy() # Red
+                    image_a[:,:,1] += mal_a[0]  # Green
+                    image_a[:,:,2] += ben_a[0]  # Blue
+
+                    image_a *= 0.5
+                    image_a[image_a < 0] = 0
+                    image_a[image_a > 1] = 1
+                    writer.add_image(
+                        '{}/{}_label_{}'.format(
+                            mode_str,
+                            series_ndx,
+                            slice_ndx,
+                        ),
+                        image_a,
+                        self.totalTrainingSamples_count,
+                        dataformats='HWC',
+                    )
+
+
+    def logMetrics(self,
+        epoch_ndx,
+        mode_str,
+        metrics_t,
+    ):
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+        metrics_a = metrics_t.cpu().detach().numpy()
+        sum_a = metrics_a.sum(axis=1)
+        assert np.isfinite(metrics_a).all()
+
+        malLabel_mask = (metrics_a[METRICS_LABEL_NDX] == 1) | (metrics_a[METRICS_LABEL_NDX] == 3)
+
+        # allLabel_mask = (metrics_a[METRICS_LABEL_NDX] == 2) | (metrics_a[METRICS_LABEL_NDX] == 3)
+
+        allLabel_count = sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]
+        malLabel_count = sum_a[METRICS_MTP_NDX] + sum_a[METRICS_MFN_NDX]
+
+        # allCorrect_count = sum_a[METRICS_ATP_NDX]
+        # malCorrect_count = sum_a[METRICS_MTP_NDX]
+        # falsePos_count = allLabel_count - allCorrect_count
+        # falseNeg_count = malLabel_count - malCorrect_count
+
+        metrics_dict = {}
+        metrics_dict['loss/all'] = metrics_a[METRICS_LOSS_NDX].mean()
+        metrics_dict['loss/mal'] = np.nan_to_num(metrics_a[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
+        # metrics_dict['loss/all'] = metrics_a[METRICS_ALL_LOSS_NDX, allLabel_mask].mean()
+
+        # metrics_dict['correct/mal'] = sum_a[METRICS_MTP_NDX] / (sum_a[METRICS_MTP_NDX] + sum_a[METRICS_MFN_NDX]) * 100
+        # metrics_dict['correct/all'] = sum_a[METRICS_ATP_NDX] / (sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]) * 100
+
+        metrics_dict['percent_all/tp'] = sum_a[METRICS_ATP_NDX] / (allLabel_count or 1) * 100
+        metrics_dict['percent_all/fn'] = sum_a[METRICS_AFN_NDX] / (allLabel_count or 1) * 100
+        metrics_dict['percent_all/fp'] = sum_a[METRICS_AFP_NDX] / (allLabel_count or 1) * 100
+
+        metrics_dict['percent_mal/tp'] = sum_a[METRICS_MTP_NDX] / (malLabel_count or 1) * 100
+        metrics_dict['percent_mal/fn'] = sum_a[METRICS_MFN_NDX] / (malLabel_count or 1) * 100
+
+        precision = metrics_dict['pr/precision'] = sum_a[METRICS_ATP_NDX] \
+            / ((sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFP_NDX]) or 1)
+        recall    = metrics_dict['pr/recall']    = sum_a[METRICS_ATP_NDX] \
+            / ((sum_a[METRICS_ATP_NDX] + sum_a[METRICS_AFN_NDX]) or 1)
+
+        metrics_dict['pr/f1_score'] = 2 * (precision * recall) \
+            / ((precision + recall) or 1)
+
+        log.info(("E{} {:8} "
+                 + "{loss/all:.4f} loss, "
+                 + "{pr/precision:.4f} precision, "
+                 + "{pr/recall:.4f} recall, "
+                 + "{pr/f1_score:.4f} f1 score"
+                  ).format(
+            epoch_ndx,
+            mode_str,
+            **metrics_dict,
+        ))
+        log.info(("E{} {:8} "
+                  + "{loss/all:.4f} loss, "
+                  + "{percent_all/tp:-5.1f}% tp, {percent_all/fn:-5.1f}% fn, {percent_all/fp:-9.1f}% fp"
+                 # + "{correct/all:-5.1f}% correct ({allCorrect_count:} of {allLabel_count:})"
+        ).format(
+            epoch_ndx,
+            mode_str + '_all',
+            # allCorrect_count=allCorrect_count,
+            # allLabel_count=allLabel_count,
+            **metrics_dict,
+        ))
+
+        log.info(("E{} {:8} "
+                  + "{loss/mal:.4f} loss, "
+                  + "{percent_mal/tp:-5.1f}% tp, {percent_mal/fn:-5.1f}% fn"
+                 # + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
+        ).format(
+            epoch_ndx,
+            mode_str + '_mal',
+            # malCorrect_count=malCorrect_count,
+            # malLabel_count=malLabel_count,
+            **metrics_dict,
+        ))
+
+        self.initTensorboardWriters()
+        writer = getattr(self, mode_str + '_writer')
+
+        prefix_str = 'seg_'
+
+        for key, value in metrics_dict.items():
+            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)
+
+        score = 1 \
+            - metrics_dict['loss/mal'] \
+            + metrics_dict['pr/f1_score'] \
+            - metrics_dict['pr/recall'] * 0.01 \
+            - metrics_dict['loss/all']  * 0.0001
+
+        return score
+
+    # def logModelMetrics(self, model):
+    #     writer = getattr(self, 'trn_writer')
+    #
+    #     model = getattr(model, 'module', model)
+    #
+    #     for name, param in model.named_parameters():
+    #         if param.requires_grad:
+    #             min_data = float(param.data.min())
+    #             max_data = float(param.data.max())
+    #             max_extent = max(abs(min_data), abs(max_data))
+    #
+    #             # bins = [x/50*max_extent for x in range(-50, 51)]
+    #
+    #             writer.add_histogram(
+    #                 name.rsplit('.', 1)[-1] + '/' + name,
+    #                 param.data.cpu().numpy(),
+    #                 # metrics_a[METRICS_PRED_NDX, benHist_mask],
+    #                 self.totalTrainingSamples_count,
+    #                 # bins=bins,
+    #             )
+    #
+    #             # print name, param.data
+
+    def saveModel(self, type_str, epoch_ndx, isBest=False):
+        file_path = os.path.join(
+            'data-unversioned',
+            'part2',
+            'models',
+            self.cli_args.tb_prefix,
+            '{}_{}_{}.{}.state'.format(
+                type_str,
+                self.time_str,
+                self.cli_args.comment,
+                self.totalTrainingSamples_count,
+            )
+        )
+
+        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
+
+        model = self.model
+        if hasattr(model, 'module'):
+            model = model.module
+
+        state = {
+            'model_state': model.state_dict(),
+            'model_name': type(model).__name__,
+            'optimizer_state' : self.optimizer.state_dict(),
+            'optimizer_name': type(self.optimizer).__name__,
+            'epoch': epoch_ndx,
+            'totalTrainingSamples_count': self.totalTrainingSamples_count,
+        }
+        torch.save(state, file_path)
+
+        log.debug("Saved model params to {}".format(file_path))
+
+        if isBest:
+            file_path = os.path.join(
+                'data-unversioned',
+                'part2',
+                'models',
+                self.cli_args.tb_prefix,
+                '{}_{}_{}.{}.state'.format(
+                    type_str,
+                    self.time_str,
+                    self.cli_args.comment,
+                    'best',
+                )
+            )
+            torch.save(state, file_path)
+
+            log.debug("Saved model params to {}".format(file_path))
+
+
+if __name__ == '__main__':
+    sys.exit(LunaTrainingApp().main() or 0)

+ 99 - 0
p2ch14/vis.py

@@ -0,0 +1,99 @@
+import matplotlib
+matplotlib.use('nbagg')
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from p2ch14.dsets import Ct, LunaDataset
+
+clim = (-1000.0, 300)
+
+def findMalignantSamples(start_ndx=0, limit=100):
+    ds = LunaDataset(sortby_str='malignancy_size')
+
+    malignantSample_list = []
+    for sample_tup in ds.noduleInfo_list:
+        if sample_tup.isMalignant_bool:
+            print(len(malignantSample_list), sample_tup)
+            malignantSample_list.append(sample_tup)
+
+        if len(malignantSample_list) >= limit:
+            break
+
+    return malignantSample_list
+
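+# Example notebook workflow (assumed; tuple field names follow dsets.py):
+#   samples = findMalignantSamples(limit=10)
+#   showNodule(samples[0].series_uid)
+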
+def showNodule(series_uid, batch_ndx=None, **kwargs):
+    ds = LunaDataset(series_uid=series_uid, **kwargs)
+    malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x.isMalignant_bool]
+
+    if batch_ndx is None:
+        if malignant_list:
+            batch_ndx = malignant_list[0]
+        else:
+            print("Warning: no malignant samples found; using first non-malignant sample.")
+            batch_ndx = 0
+
+    ct = Ct(series_uid)
+    ct_t, malignant_t, series_uid, center_irc = ds[batch_ndx]
+    ct_a = ct_t[0].numpy()
+
+    fig = plt.figure(figsize=(30, 50))
+
+    group_list = [
+        [9, 11, 13],
+        [15, 16, 17],
+        [19, 21, 23],
+    ]
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
+    subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[int(center_irc.index)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
+    subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[:,int(center_irc.row)], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
+    subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct.hu_a[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
+    subplot.set_title('index {}'.format(int(center_irc.index)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[ct_a.shape[0]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
+    subplot.set_title('row {}'.format(int(center_irc.row)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[:,ct_a.shape[1]//2], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
+    subplot.set_title('col {}'.format(int(center_irc.col)), fontsize=30)
+    for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+        label.set_fontsize(20)
+    plt.imshow(ct_a[:,:,ct_a.shape[2]//2], clim=clim, cmap='gray')
+    plt.gca().invert_yaxis()
+
+    for row, index_list in enumerate(group_list):
+        for col, index in enumerate(index_list):
+            subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
+            subplot.set_title('slice {}'.format(index), fontsize=30)
+            for label in (subplot.get_xticklabels() + subplot.get_yticklabels()):
+                label.set_fontsize(20)
+            plt.imshow(ct_a[index], clim=clim, cmap='gray')
+
+
+    print(series_uid, batch_ndx, bool(malignant_t[0]), malignant_list)
+
+

+ 56 - 56
util/test_affine.py

@@ -133,12 +133,12 @@ def _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad):
 
     transform_ary = intrans_ary @ inscale_ary @ rotation_ary.T @ outscale_ary @ outtrans_ary
     grid_ary = reorder_ary @ rotation_ary.T @ outscale_ary @ outtrans_ary
-    transform_tensor = torch.from_numpy((rotation_ary)).to(device, torch.float32)
+    transform_t = torch.from_numpy((rotation_ary)).to(device, torch.float32)
 
-    transform_tensor = transform_tensor[:2].unsqueeze(0)
+    transform_t = transform_t[:2].unsqueeze(0)
 
-    print('transform_tensor', transform_tensor.size(), transform_tensor.dtype, transform_tensor.device)
-    print(transform_tensor)
+    print('transform_t', transform_t.size(), transform_t.dtype, transform_t.device)
+    print(transform_t)
     print('outtrans_ary', outtrans_ary.shape, outtrans_ary.dtype)
     print(outtrans_ary.round(3))
     print('outscale_ary', outscale_ary.shape, outscale_ary.dtype)
@@ -168,7 +168,7 @@ def _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad):
     prtf([0, 2])
     prtf(output_center[2:])
 
-    return transform_tensor, transform_ary, grid_ary
+    return transform_t, transform_ary, grid_ary
 
 def _buildEquivalentTransforms3d(device, input_size, output_size, angle_rad, axis_vector):
     print("_buildEquivalentTransforms2d", device, input_size, output_size, angle_rad * 180 / math.pi, axis_vector)
@@ -232,11 +232,11 @@ def _buildEquivalentTransforms3d(device, input_size, output_size, angle_rad, axi
 
     transform_ary = intrans_ary @ inscale_ary @ np.linalg.inv(scipyRotation_ary) @ outscale_ary @ outtrans_ary
     grid_ary = reorder_ary @ np.linalg.inv(scipyRotation_ary) @ outscale_ary @ outtrans_ary
-    transform_tensor = torch.from_numpy((torchRotation_ary)).to(device, torch.float32)
-    transform_tensor = transform_tensor[:3].unsqueeze(0)
+    transform_t = torch.from_numpy((torchRotation_ary)).to(device, torch.float32)
+    transform_t = transform_t[:3].unsqueeze(0)
 
-    print('transform_tensor', transform_tensor.size(), transform_tensor.dtype, transform_tensor.device)
-    print(transform_tensor)
+    print('transform_t', transform_t.size(), transform_t.dtype, transform_t.device)
+    print(transform_t)
     print('outtrans_ary', outtrans_ary.shape, outtrans_ary.dtype)
     print(outtrans_ary.round(3))
     print('outscale_ary', outscale_ary.shape, outscale_ary.dtype)
@@ -273,7 +273,7 @@ def _buildEquivalentTransforms3d(device, input_size, output_size, angle_rad, axi
 
     prtf(output_center[2:])
 
-    return transform_tensor, transform_ary, grid_ary
+    return transform_t, transform_ary, grid_ary
 
 
 def test_affine_2d_rotate0(device, affine_func2d):
@@ -282,7 +282,7 @@ def test_affine_2d_rotate0(device, affine_func2d):
     output_size = [1, 1, 5, 5]
     angle_rad = 0.
 
-    transform_tensor, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
+    transform_t, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
 
     # reference
     # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
@@ -302,17 +302,17 @@ def test_affine_2d_rotate0(device, affine_func2d):
     print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
     print(scipy_ary)
 
-    affine_tensor = affine_func2d(
-            transform_tensor,
+    affine_t = affine_func2d(
+            transform_t,
             torch.Size(output_size)
         )
 
-    print('affine_tensor', affine_tensor.size(), affine_tensor.dtype, affine_tensor.device)
-    print(affine_tensor)
+    print('affine_t', affine_t.size(), affine_t.dtype, affine_t.device)
+    print(affine_t)
 
     gridsample_ary = torch.nn.functional.grid_sample(
             torch.tensor(input_ary, device=device).to(device),
-            affine_tensor,
+            affine_t,
             padding_mode='border'
         ).to('cpu').numpy()
 
@@ -333,7 +333,7 @@ def test_affine_2d_rotate90(device, affine_func2d, input_size2dsq, output_size2d
     output_size = output_size2dsq
     angle_rad = 0.25 * math.pi * 2
 
-    transform_tensor, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
+    transform_t, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
 
     # reference
     # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
@@ -360,17 +360,17 @@ def test_affine_2d_rotate90(device, affine_func2d, input_size2dsq, output_size2d
     assert np.abs(scipy_ary[-1,-1] - input_ary[0,0,-1,0]).max() < 1e-6
     assert np.abs(scipy_ary[-1,0] - input_ary[0,0,0,0]).max() < 1e-6
 
-    affine_tensor = affine_func2d(
-            transform_tensor,
+    affine_t = affine_func2d(
+            transform_t,
             torch.Size(output_size)
         )
 
-    print('affine_tensor', affine_tensor.size(), affine_tensor.dtype, affine_tensor.device)
-    print(affine_tensor)
+    print('affine_t', affine_t.size(), affine_t.dtype, affine_t.device)
+    print(affine_t)
 
     gridsample_ary = torch.nn.functional.grid_sample(
             torch.tensor(input_ary, device=device).to(device),
-            affine_tensor,
+            affine_t,
             padding_mode='border'
         ).to('cpu').numpy()
 
@@ -393,7 +393,7 @@ def test_affine_2d_rotate45(device, affine_func2d):
     output_size = [1, 1, 3, 3]
     angle_rad = 0.125 * math.pi * 2
 
-    transform_tensor, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
+    transform_t, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
 
     # reference
     # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
@@ -413,17 +413,17 @@ def test_affine_2d_rotate45(device, affine_func2d):
     print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
     print(scipy_ary)
 
-    affine_tensor = affine_func2d(
-            transform_tensor,
+    affine_t = affine_func2d(
+            transform_t,
             torch.Size(output_size)
         )
 
-    print('affine_tensor', affine_tensor.size(), affine_tensor.dtype, affine_tensor.device)
-    print(affine_tensor)
+    print('affine_t', affine_t.size(), affine_t.dtype, affine_t.device)
+    print(affine_t)
 
     gridsample_ary = torch.nn.functional.grid_sample(
             torch.tensor(input_ary, device=device).to(device),
-            affine_tensor,
+            affine_t,
             padding_mode='border'
         ).to('cpu').numpy()
 
@@ -447,7 +447,7 @@ def test_affine_2d_rotateRandom(device, affine_func2d, angle_rad, input_size2d,
     input_ary[0,0,-1,0] = 6
     input_ary[0,0,-1,-1] = 8
 
-    transform_tensor, transform_ary, grid_ary = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
+    transform_t, transform_ary, grid_ary = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
 
     # reference
     # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
@@ -462,22 +462,22 @@ def test_affine_2d_rotateRandom(device, affine_func2d, angle_rad, input_size2d,
         # cval=0.0,
         prefilter=False)
 
-    affine_tensor = affine_func2d(
-            transform_tensor,
+    affine_t = affine_func2d(
+            transform_t,
             torch.Size(output_size)
         )
 
-    print('affine_tensor', affine_tensor.size(), affine_tensor.dtype, affine_tensor.device)
-    print(affine_tensor)
+    print('affine_t', affine_t.size(), affine_t.dtype, affine_t.device)
+    print(affine_t)
 
-    for r in range(affine_tensor.size(1)):
-        for c in range(affine_tensor.size(2)):
+    for r in range(affine_t.size(1)):
+        for c in range(affine_t.size(2)):
             grid_out = grid_ary @ [r, c, 1]
-            print(r, c, 'affine:', affine_tensor[0,r,c], 'grid:', grid_out[:2])
+            print(r, c, 'affine:', affine_t[0,r,c], 'grid:', grid_out[:2])
 
     gridsample_ary = torch.nn.functional.grid_sample(
             torch.tensor(input_ary, device=device).to(device),
-            affine_tensor,
+            affine_t,
             padding_mode='border'
         ).to('cpu').numpy()
 
@@ -488,14 +488,14 @@ def test_affine_2d_rotateRandom(device, affine_func2d, angle_rad, input_size2d,
     print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
     print(scipy_ary.round(3))
 
-    for r in range(affine_tensor.size(1)):
-        for c in range(affine_tensor.size(2)):
+    for r in range(affine_t.size(1)):
+        for c in range(affine_t.size(2)):
             grid_out = grid_ary @ [r, c, 1]
 
             try:
-                assert np.allclose(affine_tensor[0,r,c], grid_out[:2], atol=1e-5)
+                assert np.allclose(affine_t[0,r,c], grid_out[:2], atol=1e-5)
             except:
-                print(r, c, 'affine:', affine_tensor[0,r,c], 'grid:', grid_out[:2])
+                print(r, c, 'affine:', affine_t[0,r,c], 'grid:', grid_out[:2])
                 raise
 
     assert np.abs(scipy_ary - gridsample_ary).max() < 1e-5
@@ -515,7 +515,7 @@ def test_affine_3d_rotateRandom(device, affine_func3d, angle_rad, axis_vector, i
     input_ary[0,0, -1, -1,  0] = 8
     input_ary[0,0, -1, -1, -1] = 9
 
-    transform_tensor, transform_ary, grid_ary = _buildEquivalentTransforms3d(device, input_size, output_size, angle_rad, axis_vector)
+    transform_t, transform_ary, grid_ary = _buildEquivalentTransforms3d(device, input_size, output_size, angle_rad, axis_vector)
 
     # reference
     # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
@@ -530,26 +530,26 @@ def test_affine_3d_rotateRandom(device, affine_func3d, angle_rad, axis_vector, i
         # cval=0.0,
         prefilter=False)
 
-    affine_tensor = affine_func3d(
-            transform_tensor,
+    affine_t = affine_func3d(
+            transform_t,
             torch.Size(output_size)
         )
 
-    print('affine_tensor', affine_tensor.size(), affine_tensor.dtype, affine_tensor.device)
-    print(affine_tensor)
+    print('affine_t', affine_t.size(), affine_t.dtype, affine_t.device)
+    print(affine_t)
 
-    for i in range(affine_tensor.size(1)):
-        for r in range(affine_tensor.size(2)):
-            for c in range(affine_tensor.size(3)):
+    for i in range(affine_t.size(1)):
+        for r in range(affine_t.size(2)):
+            for c in range(affine_t.size(3)):
                 grid_out = grid_ary @ [i, r, c, 1]
-                print(i, r, c, 'affine:', affine_tensor[0,i,r,c], 'grid:', grid_out[:3].round(3))
+                print(i, r, c, 'affine:', affine_t[0,i,r,c], 'grid:', grid_out[:3].round(3))
 
     print('input_ary', input_ary.shape, input_ary.dtype)
     print(input_ary.round(3))
 
     gridsample_ary = torch.nn.functional.grid_sample(
             torch.tensor(input_ary, device=device).to(device),
-            affine_tensor,
+            affine_t,
             padding_mode='border'
         ).to('cpu').numpy()
 
@@ -558,14 +558,14 @@ def test_affine_3d_rotateRandom(device, affine_func3d, angle_rad, axis_vector, i
     print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
     print(scipy_ary.round(3))
 
-    for i in range(affine_tensor.size(1)):
-        for r in range(affine_tensor.size(2)):
-            for c in range(affine_tensor.size(3)):
+    for i in range(affine_t.size(1)):
+        for r in range(affine_t.size(2)):
+            for c in range(affine_t.size(3)):
                 grid_out = grid_ary @ [i, r, c, 1]
                 try:
-                    assert np.allclose(affine_tensor[0,i,r,c], grid_out[:3], atol=1e-5)
+                    assert np.allclose(affine_t[0,i,r,c], grid_out[:3], atol=1e-5)
                 except:
-                    print(i, r, c, 'affine:', affine_tensor[0,i,r,c], 'grid:', grid_out[:3].round(3))
+                    print(i, r, c, 'affine:', affine_t[0,i,r,c], 'grid:', grid_out[:3].round(3))
                     raise
 
     assert np.abs(scipy_ary - gridsample_ary).max() < 1e-5
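+
+# A standalone sketch of the equivalence these tests exercise (assumed
+# setup; `affine_func3d` is expected to behave like
+# torch.nn.functional.affine_grid):
+#
+#   theta = torch.eye(3, 4).unsqueeze(0)  # identity 3D affine, shape (1, 3, 4)
+#   grid = torch.nn.functional.affine_grid(theta, torch.Size([1, 1, 2, 4, 4]))
+#   out = torch.nn.functional.grid_sample(input_t, grid, padding_mode='border')
+#   # with `input_t` any (1, 1, 2, 4, 4) tensor, `out` reproduces it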

+ 24 - 9
util/util.py

@@ -17,29 +17,37 @@ IrcTuple = collections.namedtuple('IrcTuple', ['index', 'row', 'col'])
 XyzTuple = collections.namedtuple('XyzTuple', ['x', 'y', 'z'])
 
 def xyz2irc(coord_xyz, origin_xyz, vxSize_xyz, direction_tup):
-    # Note: _cri means Col,Row,Index
     if direction_tup == (1, 0, 0, 0, 1, 0, 0, 0, 1):
         direction_ary = np.ones((3,))
     elif direction_tup == (-1, 0, 0, 0, -1, 0, 0, 0, 1):
         direction_ary = np.array((-1, -1, 1))
     else:
-        raise Exception("Unsupported direction_tup: {}".format(direction_tup))
-
-    coord_cri = (np.array(coord_xyz) - np.array(origin_xyz)) / np.array(vxSize_xyz)
+        raise Exception(
+            "Unsupported direction_tup: {}".format(direction_tup),
+        )
+
+    coord_cri = (
+            np.array(coord_xyz)
+            - np.array(origin_xyz)
+        ) / np.array(vxSize_xyz)
     coord_cri *= direction_ary
     return IrcTuple(*list(reversed(coord_cri.tolist())))
 
 def irc2xyz(coord_irc, origin_xyz, vxSize_xyz, direction_tup):
-    # Note: _cri means Col,Row,Index
     coord_cri = np.array(list(reversed(coord_irc)))
     if direction_tup == (1, 0, 0, 0, 1, 0, 0, 0, 1):
         direction_ary = np.ones((3,))
     elif direction_tup == (-1, 0, 0, 0, -1, 0, 0, 0, 1):
         direction_ary = np.array((-1, -1, 1))
     else:
-        raise Exception("Unsupported direction_tup: {}".format(direction_tup))
-
-    coord_xyz = coord_cri * direction_ary * np.array(vxSize_xyz) + np.array(origin_xyz)
+        raise Exception(
+            "Unsupported direction_tup: {}".format(direction_tup),
+        )
+
+    coord_xyz = coord_cri \
+        * direction_ary \
+        * np.array(vxSize_xyz) \
+        + np.array(origin_xyz)
     return XyzTuple(*coord_xyz.tolist())
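+
+# Worked example (made-up geometry): with origin_xyz = (-200.0, -200.0, -100.0),
+# vxSize_xyz = (0.7, 0.7, 2.5), and the identity direction_tup,
+# xyz2irc((0, 0, 0), ...) gives IrcTuple(index=40.0, row=~285.7, col=~285.7),
+# and irc2xyz maps that back to (0, 0, 0).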
 
 
@@ -150,7 +158,14 @@ def prhist(ary, prefix_str=None, **kwargs):
 #     print('{:10,}'.format(total_bytes), "total bytes")
 
 
-def enumerateWithEstimate(iter, desc_str, start_ndx=0, print_ndx=4, backoff=2, iter_len=None):
+def enumerateWithEstimate(
+        iter,
+        desc_str,
+        start_ndx=0,
+        print_ndx=4,
+        backoff=2,
+        iter_len=None,
+):
     """
     In terms of behavior, `enumerateWithEstimate` is almost identical
     to the standard `enumerate` (the differences are things like how

Some files were not shown because too many files changed in this diff