浏览代码

Code for first printing

Eli Stevens 5 年之前
父节点
当前提交
fb46b90662

+ 2 - 1
p1ch4/1_image_dog.ipynb

@@ -65,7 +65,8 @@
     "import os\n",
     "import os\n",
     "\n",
     "\n",
     "data_dir = '../data/p1ch4/image-cats/'\n",
     "data_dir = '../data/p1ch4/image-cats/'\n",
-    "filenames = [name for name in os.listdir(data_dir) if os.path.splitext(name)[-1] == '.png']\n",
+    "filenames = [name for name in os.listdir(data_dir)\n",
+    "             if os.path.splitext(name)[-1] == '.png']\n",
     "for i, filename in enumerate(filenames):\n",
     "for i, filename in enumerate(filenames):\n",
     "    img_arr = imageio.imread(os.path.join(data_dir, filename))\n",
     "    img_arr = imageio.imread(os.path.join(data_dir, filename))\n",
     "    img_t = torch.from_numpy(img_arr)\n",
     "    img_t = torch.from_numpy(img_arr)\n",

+ 3 - 4
p1ch4/2_volumetric_ct.ipynb

@@ -22,7 +22,7 @@
      "text": [
      "text": [
       "Reading DICOM (examining files): 1/99 files (1.0%99/99 files (100.0%)\n",
       "Reading DICOM (examining files): 1/99 files (1.0%99/99 files (100.0%)\n",
       "  Found 1 correct series.\n",
       "  Found 1 correct series.\n",
-      "Reading DICOM (loading data): 87/99  (87.999/99  (100.0%)\n"
+      "Reading DICOM (loading data): 31/99  (31.392/99  (92.999/99  (100.0%)\n"
      ]
      ]
     },
     },
     {
     {
@@ -52,7 +52,7 @@
     {
     {
      "data": {
      "data": {
       "text/plain": [
       "text/plain": [
-       "torch.Size([1, 512, 512, 99])"
+       "torch.Size([1, 99, 512, 512])"
       ]
       ]
      },
      },
      "execution_count": 3,
      "execution_count": 3,
@@ -62,7 +62,6 @@
    ],
    ],
    "source": [
    "source": [
     "vol = torch.from_numpy(vol_arr).float()\n",
     "vol = torch.from_numpy(vol_arr).float()\n",
-    "vol = torch.transpose(vol, 0, 2)\n",
     "vol = torch.unsqueeze(vol, 0)\n",
     "vol = torch.unsqueeze(vol, 0)\n",
     "\n",
     "\n",
     "vol.shape"
     "vol.shape"
@@ -120,7 +119,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.6.6"
+   "version": "3.7.6"
   }
   }
  },
  },
  "nbformat": 4,
  "nbformat": 4,

+ 8 - 7
p1ch4/3_tabular_wine.ipynb

@@ -8,7 +8,7 @@
    "source": [
    "source": [
     "import numpy as np\n",
     "import numpy as np\n",
     "import torch\n",
     "import torch\n",
-    "torch.set_printoptions(edgeitems=2, precision=2)"
+    "torch.set_printoptions(edgeitems=2, precision=2, linewidth=75)"
    ]
    ]
   },
   },
   {
   {
@@ -36,7 +36,8 @@
    "source": [
    "source": [
     "import csv\n",
     "import csv\n",
     "wine_path = \"../data/p1ch4/tabular-wine/winequality-white.csv\"\n",
     "wine_path = \"../data/p1ch4/tabular-wine/winequality-white.csv\"\n",
-    "wineq_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=\";\", skiprows=1)\n",
+    "wineq_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=\";\",\n",
+    "                         skiprows=1)\n",
     "wineq_numpy"
     "wineq_numpy"
    ]
    ]
   },
   },
@@ -222,8 +223,8 @@
     {
     {
      "data": {
      "data": {
       "text/plain": [
       "text/plain": [
-       "tensor([6.85e+00, 2.78e-01, 3.34e-01, 6.39e+00, 4.58e-02, 3.53e+01, 1.38e+02,\n",
-       "        9.94e-01, 3.19e+00, 4.90e-01, 1.05e+01])"
+       "tensor([6.85e+00, 2.78e-01, 3.34e-01, 6.39e+00, 4.58e-02, 3.53e+01,\n",
+       "        1.38e+02, 9.94e-01, 3.19e+00, 4.90e-01, 1.05e+01])"
       ]
       ]
      },
      },
      "execution_count": 10,
      "execution_count": 10,
@@ -244,8 +245,8 @@
     {
     {
      "data": {
      "data": {
       "text/plain": [
       "text/plain": [
-       "tensor([7.12e-01, 1.02e-02, 1.46e-02, 2.57e+01, 4.77e-04, 2.89e+02, 1.81e+03,\n",
-       "        8.95e-06, 2.28e-02, 1.30e-02, 1.51e+00])"
+       "tensor([7.12e-01, 1.02e-02, 1.46e-02, 2.57e+01, 4.77e-04, 2.89e+02,\n",
+       "        1.81e+03, 8.95e-06, 2.28e-02, 1.30e-02, 1.51e+00])"
       ]
       ]
      },
      },
      "execution_count": 11,
      "execution_count": 11,
@@ -448,7 +449,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.7.5"
+   "version": "3.7.6"
   }
   }
  },
  },
  "nbformat": 4,
  "nbformat": 4,

+ 21 - 15
p1ch4/4_time_series_bikes.ipynb

@@ -8,7 +8,7 @@
    "source": [
    "source": [
     "import numpy as np\n",
     "import numpy as np\n",
     "import torch\n",
     "import torch\n",
-    "torch.set_printoptions(edgeitems=2, threshold=50)"
+    "torch.set_printoptions(edgeitems=2, threshold=50, linewidth=75)"
    ]
    ]
   },
   },
   {
   {
@@ -32,11 +32,12 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "bikes_numpy = np.loadtxt(\"../data/p1ch4/bike-sharing-dataset/hour-fixed.csv\", \n",
-    "                         dtype=np.float32, \n",
-    "                         delimiter=\",\", \n",
-    "                         skiprows=1, \n",
-    "                         converters={1: lambda x: float(x[8:10])}) # <1>\n",
+    "bikes_numpy = np.loadtxt(\n",
+    "    \"../data/p1ch4/bike-sharing-dataset/hour-fixed.csv\", \n",
+    "    dtype=np.float32, \n",
+    "    delimiter=\",\", \n",
+    "    skiprows=1, \n",
+    "    converters={1: lambda x: float(x[8:10])}) # <1>\n",
     "bikes = torch.from_numpy(bikes_numpy)\n",
     "bikes = torch.from_numpy(bikes_numpy)\n",
     "bikes"
     "bikes"
    ]
    ]
@@ -113,7 +114,8 @@
     {
     {
      "data": {
      "data": {
       "text/plain": [
       "text/plain": [
-       "tensor([1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2])"
+       "tensor([1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 2, 2,\n",
+       "        2, 2])"
       ]
       ]
      },
      },
      "execution_count": 6,
      "execution_count": 6,
@@ -162,9 +164,9 @@
     {
     {
      "data": {
      "data": {
       "text/plain": [
       "text/plain": [
-       "tensor([[ 1.0000,  1.0000,  1.0000,  0.0000,  1.0000,  0.0000,  0.0000,  6.0000,\n",
-       "          0.0000,  1.0000,  0.2400,  0.2879,  0.8100,  0.0000,  3.0000, 13.0000,\n",
-       "         16.0000,  1.0000,  0.0000,  0.0000,  0.0000]])"
+       "tensor([[ 1.0000,  1.0000,  1.0000,  0.0000,  1.0000,  0.0000,  0.0000,\n",
+       "          6.0000,  0.0000,  1.0000,  0.2400,  0.2879,  0.8100,  0.0000,\n",
+       "          3.0000, 13.0000, 16.0000,  1.0000,  0.0000,  0.0000,  0.0000]])"
       ]
       ]
      },
      },
      "execution_count": 8,
      "execution_count": 8,
@@ -193,7 +195,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "daily_weather_onehot = torch.zeros(daily_bikes.shape[0], 4, daily_bikes.shape[2])\n",
+    "daily_weather_onehot = torch.zeros(daily_bikes.shape[0], 4,\n",
+    "                                   daily_bikes.shape[2])\n",
     "daily_weather_onehot.shape"
     "daily_weather_onehot.shape"
    ]
    ]
   },
   },
@@ -214,7 +217,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "daily_weather_onehot.scatter_(1, daily_bikes[:,9,:].long().unsqueeze(1) - 1, 1.0)\n",
+    "daily_weather_onehot.scatter_(\n",
+    "    1, daily_bikes[:,9,:].long().unsqueeze(1) - 1, 1.0)\n",
     "daily_weather_onehot.shape"
     "daily_weather_onehot.shape"
    ]
    ]
   },
   },
@@ -245,7 +249,8 @@
     "temp = daily_bikes[:, 10, :]\n",
     "temp = daily_bikes[:, 10, :]\n",
     "temp_min = torch.min(temp)\n",
     "temp_min = torch.min(temp)\n",
     "temp_max = torch.max(temp)\n",
     "temp_max = torch.max(temp)\n",
-    "daily_bikes[:, 10, :] = (daily_bikes[:, 10, :] - temp_min) / (temp_max - temp_min)"
+    "daily_bikes[:, 10, :] = ((daily_bikes[:, 10, :] - temp_min)\n",
+    "                         / (temp_max - temp_min))"
    ]
    ]
   },
   },
   {
   {
@@ -255,7 +260,8 @@
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
     "temp = daily_bikes[:, 10, :]\n",
     "temp = daily_bikes[:, 10, :]\n",
-    "daily_bikes[:, 10, :] = (daily_bikes[:, 10, :] - torch.mean(temp)) / torch.std(temp)"
+    "daily_bikes[:, 10, :] = ((daily_bikes[:, 10, :] - torch.mean(temp))\n",
+    "                         / torch.std(temp))"
    ]
    ]
   }
   }
  ],
  ],
@@ -275,7 +281,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.6.6"
+   "version": "3.7.6"
   }
   }
  },
  },
  "nbformat": 4,
  "nbformat": 4,

+ 10 - 7
p1ch5/1_parameter_estimation.ipynb

@@ -13,7 +13,7 @@
     "%matplotlib inline\n",
     "%matplotlib inline\n",
     "import numpy as np\n",
     "import numpy as np\n",
     "import torch\n",
     "import torch\n",
-    "torch.set_printoptions(edgeitems=2)"
+    "torch.set_printoptions(edgeitems=2, linewidth=75)"
    ]
    ]
   },
   },
   {
   {
@@ -73,8 +73,8 @@
     {
     {
      "data": {
      "data": {
       "text/plain": [
       "text/plain": [
-       "tensor([35.7000, 55.9000, 58.2000, 81.9000, 56.3000, 48.9000, 33.9000, 21.8000,\n",
-       "        48.4000, 60.4000, 68.4000])"
+       "tensor([35.7000, 55.9000, 58.2000, 81.9000, 56.3000, 48.9000, 33.9000,\n",
+       "        21.8000, 48.4000, 60.4000, 68.4000])"
       ]
       ]
      },
      },
      "execution_count": 5,
      "execution_count": 5,
@@ -128,7 +128,8 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "shapes: x: torch.Size([]), y: torch.Size([3, 1]), z: torch.Size([1, 3]), a: torch.Size([2, 1, 1])\n",
+      "shapes: x: torch.Size([]), y: torch.Size([3, 1])\n",
+      "        z: torch.Size([1, 3]), a: torch.Size([2, 1, 1])\n",
       "x * y: torch.Size([3, 1])\n",
       "x * y: torch.Size([3, 1])\n",
       "y * z: torch.Size([3, 3])\n",
       "y * z: torch.Size([3, 3])\n",
       "y * z * a: torch.Size([2, 3, 3])\n"
       "y * z * a: torch.Size([2, 3, 3])\n"
@@ -140,7 +141,8 @@
     "y = torch.ones(3,1)\n",
     "y = torch.ones(3,1)\n",
     "z = torch.ones(1,3)\n",
     "z = torch.ones(1,3)\n",
     "a = torch.ones(2, 1, 1)\n",
     "a = torch.ones(2, 1, 1)\n",
-    "print(f\"shapes: x: {x.shape}, y: {y.shape}, z: {z.shape}, a: {a.shape}\")\n",
+    "print(f\"shapes: x: {x.shape}, y: {y.shape}\")\n",
+    "print(f\"        z: {z.shape}, a: {a.shape}\")\n",
     "print(\"x * y:\", (x * y).shape)\n",
     "print(\"x * y:\", (x * y).shape)\n",
     "print(\"y * z:\", (y * z).shape)\n",
     "print(\"y * z:\", (y * z).shape)\n",
     "print(\"y * z * a:\", (y * z * a).shape)"
     "print(\"y * z * a:\", (y * z * a).shape)"
@@ -290,7 +292,8 @@
    },
    },
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "def training_loop(n_epochs, learning_rate, params, t_u, t_c, print_params=True):\n",
+    "def training_loop(n_epochs, learning_rate, params, t_u, t_c,\n",
+    "                  print_params=True):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
     "        w, b = params\n",
     "        w, b = params\n",
     "\n",
     "\n",
@@ -637,7 +640,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.7.5"
+   "version": "3.7.6"
   },
   },
   "pycharm": {
   "pycharm": {
    "stem_cell": {
    "stem_cell": {

+ 5 - 3
p1ch5/2_autograd.ipynb

@@ -18,8 +18,10 @@
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0])\n",
-    "t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4])\n",
+    "t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0,\n",
+    "                    3.0, -4.0, 6.0, 13.0, 21.0])\n",
+    "t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9,\n",
+    "                    33.9, 21.8, 48.4, 60.4, 68.4])\n",
     "t_un = 0.1 * t_u"
     "t_un = 0.1 * t_u"
    ]
    ]
   },
   },
@@ -188,7 +190,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.6.6"
+   "version": "3.7.6"
   }
   }
  },
  },
  "nbformat": 4,
  "nbformat": 4,

+ 30 - 25
p1ch5/3_optimizers.ipynb

@@ -9,7 +9,7 @@
     "%matplotlib inline\n",
     "%matplotlib inline\n",
     "import numpy as np\n",
     "import numpy as np\n",
     "import torch\n",
     "import torch\n",
-    "torch.set_printoptions(edgeitems=2)"
+    "torch.set_printoptions(edgeitems=2, linewidth=75)"
    ]
    ]
   },
   },
   {
   {
@@ -18,8 +18,10 @@
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0])\n",
-    "t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4])\n",
+    "t_c = torch.tensor([0.5, 14.0, 15.0, 28.0, 11.0,\n",
+    "                    8.0, 3.0, -4.0, 6.0, 13.0, 21.0])\n",
+    "t_u = torch.tensor([35.7, 55.9, 58.2, 81.9, 56.3, 48.9,\n",
+    "                    33.9, 21.8, 48.4, 60.4, 68.4])\n",
     "t_un = 0.1 * t_u"
     "t_un = 0.1 * t_u"
    ]
    ]
   },
   },
@@ -56,6 +58,7 @@
        " 'Adadelta',\n",
        " 'Adadelta',\n",
        " 'Adagrad',\n",
        " 'Adagrad',\n",
        " 'Adam',\n",
        " 'Adam',\n",
+       " 'AdamW',\n",
        " 'Adamax',\n",
        " 'Adamax',\n",
        " 'LBFGS',\n",
        " 'LBFGS',\n",
        " 'Optimizer',\n",
        " 'Optimizer',\n",
@@ -184,16 +187,16 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "Epoch 500, Loss 7.860116\n",
+      "Epoch 500, Loss 7.860118\n",
       "Epoch 1000, Loss 3.828538\n",
       "Epoch 1000, Loss 3.828538\n",
       "Epoch 1500, Loss 3.092191\n",
       "Epoch 1500, Loss 3.092191\n",
       "Epoch 2000, Loss 2.957697\n",
       "Epoch 2000, Loss 2.957697\n",
       "Epoch 2500, Loss 2.933134\n",
       "Epoch 2500, Loss 2.933134\n",
       "Epoch 3000, Loss 2.928648\n",
       "Epoch 3000, Loss 2.928648\n",
       "Epoch 3500, Loss 2.927830\n",
       "Epoch 3500, Loss 2.927830\n",
-      "Epoch 4000, Loss 2.927679\n",
-      "Epoch 4500, Loss 2.927652\n",
-      "Epoch 5000, Loss 2.927647\n"
+      "Epoch 4000, Loss 2.927680\n",
+      "Epoch 4500, Loss 2.927651\n",
+      "Epoch 5000, Loss 2.927648\n"
      ]
      ]
     },
     },
     {
     {
@@ -229,7 +232,7 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "Epoch 500, Loss 7.612901\n",
+      "Epoch 500, Loss 7.612903\n",
       "Epoch 1000, Loss 3.086700\n",
       "Epoch 1000, Loss 3.086700\n",
       "Epoch 1500, Loss 2.928578\n",
       "Epoch 1500, Loss 2.928578\n",
       "Epoch 2000, Loss 2.927646\n"
       "Epoch 2000, Loss 2.927646\n"
@@ -267,7 +270,7 @@
     {
     {
      "data": {
      "data": {
       "text/plain": [
       "text/plain": [
-       "(tensor([ 8,  0,  3,  6,  4,  1,  2,  5, 10]), tensor([9, 7]))"
+       "(tensor([9, 6, 5, 8, 4, 7, 0, 1, 3]), tensor([ 2, 10]))"
       ]
       ]
      },
      },
      "execution_count": 12,
      "execution_count": 12,
@@ -309,7 +312,8 @@
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u, train_t_c, val_t_c):\n",
+    "def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u,\n",
+    "                  train_t_c, val_t_c):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
     "        train_t_p = model(train_t_u, *params) # <1>\n",
     "        train_t_p = model(train_t_u, *params) # <1>\n",
     "        train_loss = loss_fn(train_t_p, train_t_c)\n",
     "        train_loss = loss_fn(train_t_p, train_t_c)\n",
@@ -322,8 +326,8 @@
     "        optimizer.step()\n",
     "        optimizer.step()\n",
     "\n",
     "\n",
     "        if epoch <= 3 or epoch % 500 == 0:\n",
     "        if epoch <= 3 or epoch % 500 == 0:\n",
-    "            print('Epoch {}, Training loss {}, Validation loss {}'.format(\n",
-    "                epoch, float(train_loss), float(val_loss)))\n",
+    "            print(f\"Epoch {epoch}, Training loss {train_loss.item():.4f},\"\n",
+    "                  f\" Validation loss {val_loss.item():.4f}\")\n",
     "            \n",
     "            \n",
     "    return params"
     "    return params"
    ]
    ]
@@ -337,21 +341,21 @@
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "Epoch 1, Training loss 88.59708404541016, Validation loss 43.31699752807617\n",
-      "Epoch 2, Training loss 34.42190933227539, Validation loss 35.03486633300781\n",
-      "Epoch 3, Training loss 27.57990264892578, Validation loss 40.214229583740234\n",
-      "Epoch 500, Training loss 9.516923904418945, Validation loss 9.02982234954834\n",
-      "Epoch 1000, Training loss 4.543173789978027, Validation loss 2.596876621246338\n",
-      "Epoch 1500, Training loss 3.1108808517456055, Validation loss 2.9066450595855713\n",
-      "Epoch 2000, Training loss 2.6984243392944336, Validation loss 4.1561737060546875\n",
-      "Epoch 2500, Training loss 2.579646348953247, Validation loss 5.138668537139893\n",
-      "Epoch 3000, Training loss 2.5454416275024414, Validation loss 5.755766868591309\n"
+      "Epoch 1, Training loss 66.5811, Validation loss 142.3890\n",
+      "Epoch 2, Training loss 38.8626, Validation loss 64.0434\n",
+      "Epoch 3, Training loss 33.3475, Validation loss 39.4590\n",
+      "Epoch 500, Training loss 7.1454, Validation loss 9.1252\n",
+      "Epoch 1000, Training loss 3.5940, Validation loss 5.3110\n",
+      "Epoch 1500, Training loss 3.0942, Validation loss 4.1611\n",
+      "Epoch 2000, Training loss 3.0238, Validation loss 3.7693\n",
+      "Epoch 2500, Training loss 3.0139, Validation loss 3.6279\n",
+      "Epoch 3000, Training loss 3.0125, Validation loss 3.5756\n"
      ]
      ]
     },
     },
     {
     {
      "data": {
      "data": {
       "text/plain": [
       "text/plain": [
-       "tensor([  5.6473, -18.7334], requires_grad=True)"
+       "tensor([  5.1964, -16.7512], requires_grad=True)"
       ]
       ]
      },
      },
      "execution_count": 15,
      "execution_count": 15,
@@ -380,7 +384,8 @@
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u, train_t_c, val_t_c):\n",
+    "def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u,\n",
+    "                  train_t_c, val_t_c):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
     "        train_t_p = model(train_t_u, *params)\n",
     "        train_t_p = model(train_t_u, *params)\n",
     "        train_loss = loss_fn(train_t_p, train_t_c)\n",
     "        train_loss = loss_fn(train_t_p, train_t_c)\n",
@@ -397,7 +402,7 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 18,
+   "execution_count": 17,
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
@@ -425,7 +430,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.6.6"
+   "version": "3.7.6"
   }
   }
  },
  },
  "nbformat": 4,
  "nbformat": 4,

文件差异内容过多而无法显示
+ 55 - 54
p1ch6/1_neural_networks.ipynb


文件差异内容过多而无法显示
+ 17 - 15
p1ch6/2_activation_functions.ipynb


+ 20 - 23
p1ch6/3_nn_module_subclassing.ipynb

@@ -12,7 +12,7 @@
     "import torch.optim as optim\n",
     "import torch.optim as optim\n",
     "import torch.nn as nn\n",
     "import torch.nn as nn\n",
     "\n",
     "\n",
-    "torch.set_printoptions(edgeitems=2)"
+    "torch.set_printoptions(edgeitems=2, linewidth=75)"
    ]
    ]
   },
   },
   {
   {
@@ -117,39 +117,36 @@
   },
   },
   {
   {
    "cell_type": "code",
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 8,
    "metadata": {},
    "metadata": {},
    "outputs": [
    "outputs": [
     {
     {
      "name": "stdout",
      "name": "stdout",
      "output_type": "stream",
      "output_type": "stream",
      "text": [
      "text": [
-      "seq\n",
-      "0.weight              torch.Size([11, 1]) 11\n",
-      "0.bias                torch.Size([11])    11\n",
-      "2.weight              torch.Size([1, 11]) 11\n",
-      "2.bias                torch.Size([1])     1\n",
-      "\n",
-      "namedseq\n",
-      "hidden_linear.weight  torch.Size([12, 1]) 12\n",
-      "hidden_linear.bias    torch.Size([12])    12\n",
-      "output_linear.weight  torch.Size([1, 12]) 12\n",
-      "output_linear.bias    torch.Size([1])     1\n",
-      "\n",
-      "subclass\n",
-      "hidden_linear.weight  torch.Size([13, 1]) 13\n",
-      "hidden_linear.bias    torch.Size([13])    13\n",
-      "output_linear.weight  torch.Size([1, 13]) 13\n",
-      "output_linear.bias    torch.Size([1])     1\n",
-      "\n"
+      "seq\n"
+     ]
+    },
+    {
+     "ename": "TypeError",
+     "evalue": "unsupported format string passed to torch.Size.__format__",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-8-4f1be40ed447>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mname_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamed_parameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{name_str:21} {param.shape:19} {param.numel()}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mTypeError\u001b[0m: unsupported format string passed to torch.Size.__format__"
      ]
      ]
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "for type_str, model in [('seq', seq_model), ('namedseq', namedseq_model), ('subclass', subclass_model)]:\n",
+    "for type_str, model in [('seq', seq_model),\n",
+    "                        ('namedseq', namedseq_model),\n",
+    "                        ('subclass', subclass_model)]:\n",
     "    print(type_str)\n",
     "    print(type_str)\n",
     "    for name_str, param in model.named_parameters():\n",
     "    for name_str, param in model.named_parameters():\n",
-    "        print(\"{:21} {:19} {}\".format(name_str, str(param.shape), param.numel()))\n",
+    "        print(\"{:21} {:19} {}\".format(\n",
+    "            name_str, str(param.shape), param.numel()))\n",
     "        \n",
     "        \n",
     "    print()"
     "    print()"
    ]
    ]
@@ -210,7 +207,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.7.5"
+   "version": "3.7.6"
   }
   }
  },
  },
  "nbformat": 4,
  "nbformat": 4,

文件差异内容过多而无法显示
+ 5 - 26
p1ch7/1_datasets.ipynb


+ 45 - 27
p1ch7/2_birds_airplanes.ipynb

@@ -44,12 +44,13 @@
    "source": [
    "source": [
     "from torchvision import datasets, transforms\n",
     "from torchvision import datasets, transforms\n",
     "data_path = '../data-unversioned/p1ch7/'\n",
     "data_path = '../data-unversioned/p1ch7/'\n",
-    "cifar10 = datasets.CIFAR10(data_path, train=True, download=False,\n",
-    "                          transform=transforms.Compose([\n",
-    "                              transforms.ToTensor(),\n",
-    "                              transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
-    "                                                   (0.2470, 0.2435, 0.2616))\n",
-    "                          ]))"
+    "cifar10 = datasets.CIFAR10(\n",
+    "    data_path, train=True, download=False,\n",
+    "    transform=transforms.Compose([\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
+    "                             (0.2470, 0.2435, 0.2616))\n",
+    "    ]))"
    ]
    ]
   },
   },
   {
   {
@@ -58,12 +59,13 @@
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "cifar10_val = datasets.CIFAR10(data_path, train=False, download=False,\n",
-    "                          transform=transforms.Compose([\n",
-    "                              transforms.ToTensor(),\n",
-    "                              transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
-    "                                                   (0.2470, 0.2435, 0.2616))\n",
-    "                          ]))"
+    "cifar10_val = datasets.CIFAR10(\n",
+    "    data_path, train=False, download=False,\n",
+    "    transform=transforms.Compose([\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
+    "                             (0.2470, 0.2435, 0.2616))\n",
+    "    ]))"
    ]
    ]
   },
   },
   {
   {
@@ -74,8 +76,12 @@
    "source": [
    "source": [
     "label_map = {0: 0, 2: 1}\n",
     "label_map = {0: 0, 2: 1}\n",
     "class_names = ['airplane', 'bird']\n",
     "class_names = ['airplane', 'bird']\n",
-    "cifar2 = [(img, label_map[label]) for img, label in cifar10 if label in [0, 2]]\n",
-    "cifar2_val = [(img, label_map[label]) for img, label in cifar10_val if label in [0, 2]]"
+    "cifar2 = [(img, label_map[label])\n",
+    "          for img, label in cifar10 \n",
+    "          if label in [0, 2]]\n",
+    "cifar2_val = [(img, label_map[label])\n",
+    "              for img, label in cifar10_val\n",
+    "              if label in [0, 2]]"
    ]
    ]
   },
   },
   {
   {
@@ -470,7 +476,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "nll_comparison = torch.tensor([neg_log_likelihood(o) for o in [out0, out, out2, out3]])\n",
+    "nll_comparison = torch.tensor([neg_log_likelihood(o) \n",
+    "                               for o in [out0, out, out2, out3]])\n",
     "nll_comparison"
     "nll_comparison"
    ]
    ]
   },
   },
@@ -801,7 +808,8 @@
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)"
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=True)"
    ]
    ]
   },
   },
   {
   {
@@ -921,7 +929,8 @@
     "import torch.nn as nn\n",
     "import torch.nn as nn\n",
     "import torch.optim as optim\n",
     "import torch.optim as optim\n",
     "\n",
     "\n",
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=True)\n",
     "\n",
     "\n",
     "model = nn.Sequential(\n",
     "model = nn.Sequential(\n",
     "            nn.Linear(3072, 128),\n",
     "            nn.Linear(3072, 128),\n",
@@ -1066,7 +1075,8 @@
     "import torch.nn as nn\n",
     "import torch.nn as nn\n",
     "import torch.optim as optim\n",
     "import torch.optim as optim\n",
     "\n",
     "\n",
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=True)\n",
     "\n",
     "\n",
     "model = nn.Sequential(\n",
     "model = nn.Sequential(\n",
     "            nn.Linear(3072, 512),\n",
     "            nn.Linear(3072, 512),\n",
@@ -1108,7 +1118,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=False)\n",
     "\n",
     "\n",
     "correct = 0\n",
     "correct = 0\n",
     "total = 0\n",
     "total = 0\n",
@@ -1137,7 +1148,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
+    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n",
+    "                                         shuffle=False)\n",
     "\n",
     "\n",
     "correct = 0\n",
     "correct = 0\n",
     "total = 0\n",
     "total = 0\n",
@@ -1304,7 +1316,8 @@
     "import torch.nn as nn\n",
     "import torch.nn as nn\n",
     "import torch.optim as optim\n",
     "import torch.optim as optim\n",
     "\n",
     "\n",
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=True)\n",
     "\n",
     "\n",
     "model = nn.Sequential(\n",
     "model = nn.Sequential(\n",
     "            nn.Linear(3072, 1024),\n",
     "            nn.Linear(3072, 1024),\n",
@@ -1349,7 +1362,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=False)\n",
     "\n",
     "\n",
     "correct = 0\n",
     "correct = 0\n",
     "total = 0\n",
     "total = 0\n",
@@ -1378,7 +1392,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
+    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n",
+    "                                         shuffle=False)\n",
     "\n",
     "\n",
     "correct = 0\n",
     "correct = 0\n",
     "total = 0\n",
     "total = 0\n",
@@ -1970,7 +1985,8 @@
     "import torch.nn as nn\n",
     "import torch.nn as nn\n",
     "import torch.nn.functional as F\n",
     "import torch.nn.functional as F\n",
     "\n",
     "\n",
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=True)\n",
     "\n",
     "\n",
     "class Net(nn.Module):\n",
     "class Net(nn.Module):\n",
     "    def __init__(self):\n",
     "    def __init__(self):\n",
@@ -2016,7 +2032,8 @@
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=False)\n",
     "\n",
     "\n",
     "correct = 0\n",
     "correct = 0\n",
     "total = 0\n",
     "total = 0\n",
@@ -2037,7 +2054,8 @@
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
+    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n",
+    "                                         shuffle=False)\n",
     "\n",
     "\n",
     "correct = 0\n",
     "correct = 0\n",
     "total = 0\n",
     "total = 0\n",
@@ -2119,7 +2137,7 @@
    "name": "python",
    "name": "python",
    "nbconvert_exporter": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "pygments_lexer": "ipython3",
-   "version": "3.7.5"
+   "version": "3.7.6"
   }
   }
  },
  },
  "nbformat": 4,
  "nbformat": 4,

+ 78 - 40
p1ch8/1_convolution.ipynb

@@ -65,12 +65,13 @@
    "source": [
    "source": [
     "from torchvision import datasets, transforms\n",
     "from torchvision import datasets, transforms\n",
     "data_path = '../data-unversioned/p1ch6/'\n",
     "data_path = '../data-unversioned/p1ch6/'\n",
-    "cifar10 = datasets.CIFAR10(data_path, train=True, download=True,\n",
-    "                          transform=transforms.Compose([\n",
-    "                              transforms.ToTensor(),\n",
-    "                              transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
-    "                                                   (0.2470, 0.2435, 0.2616))\n",
-    "                          ]))"
+    "cifar10 = datasets.CIFAR10(\n",
+    "    data_path, train=True, download=True,\n",
+    "    transform=transforms.Compose([\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
+    "                             (0.2470, 0.2435, 0.2616))\n",
+    "    ]))"
    ]
    ]
   },
   },
   {
   {
@@ -87,12 +88,13 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "cifar10_val = datasets.CIFAR10(data_path, train=False, download=True,\n",
-    "                          transform=transforms.Compose([\n",
-    "                              transforms.ToTensor(),\n",
-    "                              transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
-    "                                                   (0.2470, 0.2435, 0.2616))\n",
-    "                          ]))"
+    "cifar10_val = datasets.CIFAR10(\n",
+    "    data_path, train=False, download=True,\n",
+    "    transform=transforms.Compose([\n",
+    "        transforms.ToTensor(),\n",
+    "        transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
+    "                             (0.2470, 0.2435, 0.2616))\n",
+    "    ]))"
    ]
    ]
   },
   },
   {
   {
@@ -103,8 +105,12 @@
    "source": [
    "source": [
     "label_map = {0: 0, 2: 1}\n",
     "label_map = {0: 0, 2: 1}\n",
     "class_names = ['airplane', 'bird']\n",
     "class_names = ['airplane', 'bird']\n",
-    "cifar2 = [(img, label_map[label]) for img, label in cifar10 if label in [0, 2]]\n",
-    "cifar2_val = [(img, label_map[label]) for img, label in cifar10_val if label in [0, 2]]"
+    "cifar2 = [(img, label_map[label])\n",
+    "          for img, label in cifar10\n",
+    "          if label in [0, 2]]\n",
+    "cifar2_val = [(img, label_map[label])\n",
+    "              for img, label in cifar10_val\n",
+    "              if label in [0, 2]]"
    ]
    ]
   },
   },
   {
   {
@@ -140,7 +146,9 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "numel_list = [p.numel() for p in connected_model.parameters() if p.requires_grad == True]\n",
+    "numel_list = [p.numel()\n",
+    "              for p in connected_model.parameters()\n",
+    "              if p.requires_grad == True]\n",
     "sum(numel_list), numel_list"
     "sum(numel_list), numel_list"
    ]
    ]
   },
   },
@@ -648,18 +656,23 @@
     "    for epoch in range(1, n_epochs + 1):  # <2>\n",
     "    for epoch in range(1, n_epochs + 1):  # <2>\n",
     "        loss_train = 0.0\n",
     "        loss_train = 0.0\n",
     "        for imgs, labels in train_loader:  # <3>\n",
     "        for imgs, labels in train_loader:  # <3>\n",
+    "            \n",
     "            outputs = model(imgs)  # <4>\n",
     "            outputs = model(imgs)  # <4>\n",
+    "            \n",
     "            loss = loss_fn(outputs, labels)  # <5>\n",
     "            loss = loss_fn(outputs, labels)  # <5>\n",
     "\n",
     "\n",
     "            optimizer.zero_grad()  # <6>\n",
     "            optimizer.zero_grad()  # <6>\n",
+    "            \n",
     "            loss.backward()  # <7>\n",
     "            loss.backward()  # <7>\n",
+    "            \n",
     "            optimizer.step()  # <8>\n",
     "            optimizer.step()  # <8>\n",
     "\n",
     "\n",
     "            loss_train += loss.item()  # <9>\n",
     "            loss_train += loss.item()  # <9>\n",
     "\n",
     "\n",
     "        if epoch == 1 or epoch % 10 == 0:\n",
     "        if epoch == 1 or epoch % 10 == 0:\n",
     "            print('{} Epoch {}, Training loss {}'.format(\n",
     "            print('{} Epoch {}, Training loss {}'.format(\n",
-    "                datetime.datetime.now(), epoch, loss_train / len(train_loader)))  # <10>"
+    "                datetime.datetime.now(), epoch,\n",
+    "                loss_train / len(train_loader)))  # <10>"
    ]
    ]
   },
   },
   {
   {
@@ -686,7 +699,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)  # <1>\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=True)  # <1>\n",
     "\n",
     "\n",
     "model = Net()  #  <2>\n",
     "model = Net()  #  <2>\n",
     "optimizer = optim.SGD(model.parameters(), lr=1e-2)  #  <3>\n",
     "optimizer = optim.SGD(model.parameters(), lr=1e-2)  #  <3>\n",
@@ -716,8 +730,10 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n",
-    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=False)\n",
+    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n",
+    "                                         shuffle=False)\n",
     "\n",
     "\n",
     "def validate(model, train_loader, val_loader):\n",
     "def validate(model, train_loader, val_loader):\n",
     "    for name, loader in [(\"train\", train_loader), (\"val\", val_loader)]:\n",
     "    for name, loader in [(\"train\", train_loader), (\"val\", val_loader)]:\n",
@@ -763,7 +779,8 @@
    ],
    ],
    "source": [
    "source": [
     "loaded_model = Net()  # <1>\n",
     "loaded_model = Net()  # <1>\n",
-    "loaded_model.load_state_dict(torch.load(data_path + 'birds_vs_airplanes.pt'))"
+    "loaded_model.load_state_dict(torch.load(data_path\n",
+    "                                        + 'birds_vs_airplanes.pt'))"
    ]
    ]
   },
   },
   {
   {
@@ -780,7 +797,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
+    "device = (torch.device('cuda') if torch.cuda.is_available()\n",
+    "          else torch.device('cpu'))\n",
     "print(f\"Training on device {device}.\")"
     "print(f\"Training on device {device}.\")"
    ]
    ]
   },
   },
@@ -809,7 +827,8 @@
     "\n",
     "\n",
     "        if epoch == 1 or epoch % 10 == 0:\n",
     "        if epoch == 1 or epoch % 10 == 0:\n",
     "            print('{} Epoch {}, Training loss {}'.format(\n",
     "            print('{} Epoch {}, Training loss {}'.format(\n",
-    "                datetime.datetime.now(), epoch, loss_train / len(train_loader)))"
+    "                datetime.datetime.now(), epoch,\n",
+    "                loss_train / len(train_loader)))"
    ]
    ]
   },
   },
   {
   {
@@ -836,7 +855,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=True)\n",
     "\n",
     "\n",
     "model = Net().to(device=device)  # <1>\n",
     "model = Net().to(device=device)  # <1>\n",
     "optimizer = optim.SGD(model.parameters(), lr=1e-2)\n",
     "optimizer = optim.SGD(model.parameters(), lr=1e-2)\n",
@@ -866,8 +886,10 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n",
-    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
+    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
+    "                                           shuffle=False)\n",
+    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n",
+    "                                         shuffle=False)\n",
     "all_acc_dict = collections.OrderedDict()\n",
     "all_acc_dict = collections.OrderedDict()\n",
     "\n",
     "\n",
     "def validate(model, train_loader, val_loader):\n",
     "def validate(model, train_loader, val_loader):\n",
@@ -910,7 +932,9 @@
    ],
    ],
    "source": [
    "source": [
     "loaded_model = Net().to(device=device)\n",
     "loaded_model = Net().to(device=device)\n",
-    "loaded_model.load_state_dict(torch.load(data_path + 'birds_vs_airplanes.pt', map_location=device))"
+    "loaded_model.load_state_dict(torch.load(data_path\n",
+    "                                        + 'birds_vs_airplanes.pt',\n",
+    "                                        map_location=device))"
    ]
    ]
   },
   },
   {
   {
@@ -998,7 +1022,8 @@
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        self.n_chans1 = n_chans1\n",
     "        self.n_chans1 = n_chans1\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
-    "        self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, padding=1)\n",
+    "        self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3,\n",
+    "                               padding=1)\n",
     "        self.fc1 = nn.Linear(8 * 8 * n_chans1 // 2, 32)\n",
     "        self.fc1 = nn.Linear(8 * 8 * n_chans1 // 2, 32)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
     "        \n",
     "        \n",
@@ -1078,7 +1103,8 @@
    "metadata": {},
    "metadata": {},
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
-    "def training_loop_l2reg(n_epochs, optimizer, model, loss_fn, train_loader):\n",
+    "def training_loop_l2reg(n_epochs, optimizer, model, loss_fn,\n",
+    "                        train_loader):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
     "        loss_train = 0.0\n",
     "        loss_train = 0.0\n",
     "        for imgs, labels in train_loader:\n",
     "        for imgs, labels in train_loader:\n",
@@ -1088,7 +1114,8 @@
     "            loss = loss_fn(outputs, labels)\n",
     "            loss = loss_fn(outputs, labels)\n",
     "\n",
     "\n",
     "            l2_lambda = 0.001\n",
     "            l2_lambda = 0.001\n",
-    "            l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())  # <1>\n",
+    "            l2_norm = sum(p.pow(2.0).sum()\n",
+    "                          for p in model.parameters())  # <1>\n",
     "            loss = loss + l2_lambda * l2_norm\n",
     "            loss = loss + l2_lambda * l2_norm\n",
     "\n",
     "\n",
     "            optimizer.zero_grad()\n",
     "            optimizer.zero_grad()\n",
@@ -1098,7 +1125,8 @@
     "            loss_train += loss.item()\n",
     "            loss_train += loss.item()\n",
     "        if epoch == 1 or epoch % 10 == 0:\n",
     "        if epoch == 1 or epoch % 10 == 0:\n",
     "            print('{} Epoch {}, Training loss {}'.format(\n",
     "            print('{} Epoch {}, Training loss {}'.format(\n",
-    "                datetime.datetime.now(), epoch, loss_train / len(train_loader)))\n"
+    "                datetime.datetime.now(), epoch,\n",
+    "                loss_train / len(train_loader)))\n"
    ]
    ]
   },
   },
   {
   {
@@ -1153,7 +1181,8 @@
     "        self.n_chans1 = n_chans1\n",
     "        self.n_chans1 = n_chans1\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
     "        self.conv1_dropout = nn.Dropout2d(p=0.4)\n",
     "        self.conv1_dropout = nn.Dropout2d(p=0.4)\n",
-    "        self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, padding=1)\n",
+    "        self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3,\n",
+    "                               padding=1)\n",
     "        self.conv2_dropout = nn.Dropout2d(p=0.4)\n",
     "        self.conv2_dropout = nn.Dropout2d(p=0.4)\n",
     "        self.fc1 = nn.Linear(8 * 8 * n_chans1 // 2, 32)\n",
     "        self.fc1 = nn.Linear(8 * 8 * n_chans1 // 2, 32)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
@@ -1221,7 +1250,8 @@
     "        self.n_chans1 = n_chans1\n",
     "        self.n_chans1 = n_chans1\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
     "        self.conv1_batchnorm = nn.BatchNorm2d(num_features=n_chans1)\n",
     "        self.conv1_batchnorm = nn.BatchNorm2d(num_features=n_chans1)\n",
-    "        self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, padding=1)\n",
+    "        self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, \n",
+    "                               padding=1)\n",
     "        self.conv2_batchnorm = nn.BatchNorm2d(num_features=n_chans1 // 2)\n",
     "        self.conv2_batchnorm = nn.BatchNorm2d(num_features=n_chans1 // 2)\n",
     "        self.fc1 = nn.Linear(8 * 8 * n_chans1 // 2, 32)\n",
     "        self.fc1 = nn.Linear(8 * 8 * n_chans1 // 2, 32)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
@@ -1288,8 +1318,10 @@
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        self.n_chans1 = n_chans1\n",
     "        self.n_chans1 = n_chans1\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
-    "        self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, padding=1)\n",
-    "        self.conv3 = nn.Conv2d(n_chans1 // 2, n_chans1 // 2, kernel_size=3, padding=1)\n",
+    "        self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3,\n",
+    "                               padding=1)\n",
+    "        self.conv3 = nn.Conv2d(n_chans1 // 2, n_chans1 // 2,\n",
+    "                               kernel_size=3, padding=1)\n",
     "        self.fc1 = nn.Linear(4 * 4 * n_chans1 // 2, 32)\n",
     "        self.fc1 = nn.Linear(4 * 4 * n_chans1 // 2, 32)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
     "        \n",
     "        \n",
@@ -1354,8 +1386,10 @@
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        self.n_chans1 = n_chans1\n",
     "        self.n_chans1 = n_chans1\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
-    "        self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, padding=1)\n",
-    "        self.conv3 = nn.Conv2d(n_chans1 // 2, n_chans1 // 2, kernel_size=3, padding=1)\n",
+    "        self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3,\n",
+    "                               padding=1)\n",
+    "        self.conv3 = nn.Conv2d(n_chans1 // 2, n_chans1 // 2,\n",
+    "                               kernel_size=3, padding=1)\n",
     "        self.fc1 = nn.Linear(4 * 4 * n_chans1 // 2, 32)\n",
     "        self.fc1 = nn.Linear(4 * 4 * n_chans1 // 2, 32)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
     "        \n",
     "        \n",
@@ -1419,9 +1453,11 @@
     "class ResBlock(nn.Module):\n",
     "class ResBlock(nn.Module):\n",
     "    def __init__(self, n_chans):\n",
     "    def __init__(self, n_chans):\n",
     "        super(ResBlock, self).__init__()\n",
     "        super(ResBlock, self).__init__()\n",
-    "        self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)  # <1>\n",
+    "        self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3,\n",
+    "                              padding=1, bias=False)  # <1>\n",
     "        self.batch_norm = nn.BatchNorm2d(num_features=n_chans)\n",
     "        self.batch_norm = nn.BatchNorm2d(num_features=n_chans)\n",
-    "        torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu')  # <2>\n",
+    "        torch.nn.init.kaiming_normal_(self.conv.weight,\n",
+    "                                      nonlinearity='relu')  # <2>\n",
     "        torch.nn.init.constant_(self.batch_norm.weight, 0.5)\n",
     "        torch.nn.init.constant_(self.batch_norm.weight, 0.5)\n",
     "        torch.nn.init.zeros_(self.batch_norm.bias)\n",
     "        torch.nn.init.zeros_(self.batch_norm.bias)\n",
     "\n",
     "\n",
@@ -1443,7 +1479,8 @@
     "        super().__init__()\n",
     "        super().__init__()\n",
     "        self.n_chans1 = n_chans1\n",
     "        self.n_chans1 = n_chans1\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
     "        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
-    "        self.resblocks = nn.Sequential(* [ResBlock(n_chans=n_chans1)] * n_blocks)\n",
+    "        self.resblocks = nn.Sequential(\n",
+    "            *(n_blocks * [ResBlock(n_chans=n_chans1)]))\n",
     "        self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)\n",
     "        self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
     "        self.fc2 = nn.Linear(32, 2)\n",
     "        \n",
     "        \n",
@@ -1523,7 +1560,8 @@
     "width =0.3\n",
     "width =0.3\n",
     "plt.bar(np.arange(len(trn_acc)), trn_acc, width=width, label='train')\n",
     "plt.bar(np.arange(len(trn_acc)), trn_acc, width=width, label='train')\n",
     "plt.bar(np.arange(len(val_acc))+ width, val_acc, width=width, label='val')\n",
     "plt.bar(np.arange(len(val_acc))+ width, val_acc, width=width, label='val')\n",
-    "plt.xticks(np.arange(len(val_acc))+ width/2, list(all_acc_dict.keys()), rotation=60)\n",
+    "plt.xticks(np.arange(len(val_acc))+ width/2, list(all_acc_dict.keys()),\n",
+    "           rotation=60)\n",
     "plt.ylabel('accuracy')\n",
     "plt.ylabel('accuracy')\n",
     "plt.legend(loc='lower right')\n",
     "plt.legend(loc='lower right')\n",
     "plt.ylim(0.7, 1)\n",
     "plt.ylim(0.7, 1)\n",

文件差异内容过多而无法显示
+ 53 - 1195
p2_run_everything.ipynb


+ 37 - 12
p2ch10/dsets.py

@@ -22,17 +22,20 @@ log = logging.getLogger(__name__)
 # log.setLevel(logging.INFO)
 # log.setLevel(logging.INFO)
 log.setLevel(logging.DEBUG)
 log.setLevel(logging.DEBUG)
 
 
-raw_cache = getCache('part2ch09_raw')
+raw_cache = getCache('part2ch10_raw')
 
 
-CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, diameter_mm, series_uid, center_xyz')
+CandidateInfoTuple = namedtuple(
+    'CandidateInfoTuple',
+    'isNodule_bool, diameter_mm, series_uid, center_xyz',
+)
 
 
 @functools.lru_cache(1)
 @functools.lru_cache(1)
-def getCandidateInfoList(requireDataOnDisk_bool=True):
+def getCandidateInfoList(requireOnDisk_bool=True):
     # We construct a set with all series_uids that are present on disk.
     # We construct a set with all series_uids that are present on disk.
     # This will let us use the data, even if we haven't downloaded all of
     # This will let us use the data, even if we haven't downloaded all of
     # the subsets yet.
     # the subsets yet.
     mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
     mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
-    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
 
 
     diameter_dict = {}
     diameter_dict = {}
     with open('data/part2/luna/annotations.csv', "r") as f:
     with open('data/part2/luna/annotations.csv', "r") as f:
@@ -41,21 +44,24 @@ def getCandidateInfoList(requireDataOnDisk_bool=True):
             annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
             annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
             annotationDiameter_mm = float(row[4])
             annotationDiameter_mm = float(row[4])
 
 
-            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
+            diameter_dict.setdefault(series_uid, []).append(
+                (annotationCenter_xyz, annotationDiameter_mm)
+            )
 
 
     candidateInfo_list = []
     candidateInfo_list = []
     with open('data/part2/luna/candidates.csv', "r") as f:
     with open('data/part2/luna/candidates.csv', "r") as f:
         for row in list(csv.reader(f))[1:]:
         for row in list(csv.reader(f))[1:]:
             series_uid = row[0]
             series_uid = row[0]
 
 
-            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                 continue
                 continue
 
 
             isNodule_bool = bool(int(row[4]))
             isNodule_bool = bool(int(row[4]))
             candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
             candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
 
 
             candidateDiameter_mm = 0.0
             candidateDiameter_mm = 0.0
-            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
+            for annotation_tup in diameter_dict.get(series_uid, []):
+                annotationCenter_xyz, annotationDiameter_mm = annotation_tup
                 for i in range(3):
                 for i in range(3):
                     delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                     delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                     if delta_mm > annotationDiameter_mm / 4:
                     if delta_mm > annotationDiameter_mm / 4:
@@ -64,14 +70,21 @@ def getCandidateInfoList(requireDataOnDisk_bool=True):
                     candidateDiameter_mm = annotationDiameter_mm
                     candidateDiameter_mm = annotationDiameter_mm
                     break
                     break
 
 
-            candidateInfo_list.append(CandidateInfoTuple(isNodule_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+            candidateInfo_list.append(CandidateInfoTuple(
+                isNodule_bool,
+                candidateDiameter_mm,
+                series_uid,
+                candidateCenter_xyz,
+            ))
 
 
     candidateInfo_list.sort(reverse=True)
     candidateInfo_list.sort(reverse=True)
     return candidateInfo_list
     return candidateInfo_list
 
 
 class Ct:
 class Ct:
     def __init__(self, series_uid):
     def __init__(self, series_uid):
-        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
+        mhd_path = glob.glob(
+            'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
+        )[0]
 
 
         ct_mhd = sitk.ReadImage(mhd_path)
         ct_mhd = sitk.ReadImage(mhd_path)
         ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
         ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
@@ -90,7 +103,12 @@ class Ct:
         self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
         self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
 
 
     def getRawCandidate(self, center_xyz, width_irc):
     def getRawCandidate(self, center_xyz, width_irc):
-        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_a)
+        center_irc = xyz2irc(
+            center_xyz,
+            self.origin_xyz,
+            self.vxSize_xyz,
+            self.direction_a,
+        )
 
 
         slice_list = []
         slice_list = []
         for axis, center_val in enumerate(center_irc):
         for axis, center_val in enumerate(center_irc):
@@ -137,7 +155,9 @@ class LunaDataset(Dataset):
         self.candidateInfo_list = copy.copy(getCandidateInfoList())
         self.candidateInfo_list = copy.copy(getCandidateInfoList())
 
 
         if series_uid:
         if series_uid:
-            self.candidateInfo_list = [x for x in self.candidateInfo_list if x.series_uid == series_uid]
+            self.candidateInfo_list = [
+                x for x in self.candidateInfo_list if x.series_uid == series_uid
+            ]
 
 
         if isValSet_bool:
         if isValSet_bool:
             assert val_stride > 0, val_stride
             assert val_stride > 0, val_stride
@@ -177,4 +197,9 @@ class LunaDataset(Dataset):
             dtype=torch.long,
             dtype=torch.long,
         )
         )
 
 
-        return candidate_t, pos_t, candidateInfo_tup.series_uid, torch.tensor(center_irc)
+        return (
+            candidate_t,
+            pos_t,
+            candidateInfo_tup.series_uid,
+            torch.tensor(center_irc),
+        )

文件差异内容过多而无法显示
+ 885 - 9
p2ch10_explore_data.ipynb


+ 30 - 10
p2ch11/dsets.py

@@ -25,15 +25,18 @@ log.setLevel(logging.DEBUG)
 
 
 raw_cache = getCache('part2ch11_raw')
 raw_cache = getCache('part2ch11_raw')
 
 
-CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, diameter_mm, series_uid, center_xyz')
+CandidateInfoTuple = namedtuple(
+    'CandidateInfoTuple',
+    'isNodule_bool, diameter_mm, series_uid, center_xyz',
+)
 
 
 @functools.lru_cache(1)
 @functools.lru_cache(1)
-def getCandidateInfoList(requireDataOnDisk_bool=True):
+def getCandidateInfoList(requireOnDisk_bool=True):
     # We construct a set with all series_uids that are present on disk.
     # We construct a set with all series_uids that are present on disk.
     # This will let us use the data, even if we haven't downloaded all of
     # This will let us use the data, even if we haven't downloaded all of
     # the subsets yet.
     # the subsets yet.
     mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
     mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
-    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
 
 
     diameter_dict = {}
     diameter_dict = {}
     with open('data/part2/luna/annotations.csv', "r") as f:
     with open('data/part2/luna/annotations.csv', "r") as f:
@@ -42,21 +45,24 @@ def getCandidateInfoList(requireDataOnDisk_bool=True):
             annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
             annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
             annotationDiameter_mm = float(row[4])
             annotationDiameter_mm = float(row[4])
 
 
-            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
+            diameter_dict.setdefault(series_uid, []).append(
+                (annotationCenter_xyz, annotationDiameter_mm),
+            )
 
 
     candidateInfo_list = []
     candidateInfo_list = []
     with open('data/part2/luna/candidates.csv', "r") as f:
     with open('data/part2/luna/candidates.csv', "r") as f:
         for row in list(csv.reader(f))[1:]:
         for row in list(csv.reader(f))[1:]:
             series_uid = row[0]
             series_uid = row[0]
 
 
-            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                 continue
                 continue
 
 
             isNodule_bool = bool(int(row[4]))
             isNodule_bool = bool(int(row[4]))
             candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
             candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
 
 
             candidateDiameter_mm = 0.0
             candidateDiameter_mm = 0.0
-            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
+            for annotation_tup in diameter_dict.get(series_uid, []):
+                annotationCenter_xyz, annotationDiameter_mm = annotation_tup
                 for i in range(3):
                 for i in range(3):
                     delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                     delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                     if delta_mm > annotationDiameter_mm / 4:
                     if delta_mm > annotationDiameter_mm / 4:
@@ -65,14 +71,21 @@ def getCandidateInfoList(requireDataOnDisk_bool=True):
                     candidateDiameter_mm = annotationDiameter_mm
                     candidateDiameter_mm = annotationDiameter_mm
                     break
                     break
 
 
-            candidateInfo_list.append(CandidateInfoTuple(isNodule_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+            candidateInfo_list.append(CandidateInfoTuple(
+                isNodule_bool,
+                candidateDiameter_mm,
+                series_uid,
+                candidateCenter_xyz,
+            ))
 
 
     candidateInfo_list.sort(reverse=True)
     candidateInfo_list.sort(reverse=True)
     return candidateInfo_list
     return candidateInfo_list
 
 
 class Ct:
 class Ct:
     def __init__(self, series_uid):
     def __init__(self, series_uid):
-        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
+        mhd_path = glob.glob(
+            'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
+        )[0]
 
 
         ct_mhd = sitk.ReadImage(mhd_path)
         ct_mhd = sitk.ReadImage(mhd_path)
         ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
         ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
@@ -91,7 +104,12 @@ class Ct:
         self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
         self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
 
 
     def getRawCandidate(self, center_xyz, width_irc):
     def getRawCandidate(self, center_xyz, width_irc):
-        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_a)
+        center_irc = xyz2irc(
+            center_xyz,
+            self.origin_xyz,
+            self.vxSize_xyz,
+            self.direction_a,
+        )
 
 
         slice_list = []
         slice_list = []
         for axis, center_val in enumerate(center_irc):
         for axis, center_val in enumerate(center_irc):
@@ -140,7 +158,9 @@ class LunaDataset(Dataset):
         self.candidateInfo_list = copy.copy(getCandidateInfoList())
         self.candidateInfo_list = copy.copy(getCandidateInfoList())
 
 
         if series_uid:
         if series_uid:
-            self.candidateInfo_list = [x for x in self.candidateInfo_list if x.series_uid == series_uid]
+            self.candidateInfo_list = [
+                x for x in self.candidateInfo_list if x.series_uid == series_uid
+            ]
 
 
         if isValSet_bool:
         if isValSet_bool:
             assert val_stride > 0, val_stride
             assert val_stride > 0, val_stride

+ 5 - 7
p2ch11/training.py

@@ -76,7 +76,7 @@ class LunaTrainingApp:
     def initModel(self):
     def initModel(self):
         model = LunaModel()
         model = LunaModel()
         if self.use_cuda:
         if self.use_cuda:
-            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
+            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
             if torch.cuda.device_count() > 1:
             if torch.cuda.device_count() > 1:
                 model = nn.DataParallel(model)
                 model = nn.DataParallel(model)
             model = model.to(self.device)
             model = model.to(self.device)
@@ -140,8 +140,6 @@ class LunaTrainingApp:
         train_dl = self.initTrainDl()
         train_dl = self.initTrainDl()
         val_dl = self.initValDl()
         val_dl = self.initValDl()
 
 
-        self.initTensorboardWriters()
-
         for epoch_ndx in range(1, self.cli_args.epochs + 1):
         for epoch_ndx in range(1, self.cli_args.epochs + 1):
 
 
             log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
             log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
@@ -283,10 +281,10 @@ class LunaTrainingApp:
         metrics_dict['loss/pos'] = \
         metrics_dict['loss/pos'] = \
             metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()
             metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()
 
 
-        metrics_dict['correct/all'] = \
-            (pos_correct + neg_correct) / np.float32(metrics_t.shape[1]) * 100
-        metrics_dict['correct/neg'] = (neg_correct) / np.float32(neg_count) * 100
-        metrics_dict['correct/pos'] = (pos_correct) / np.float32(pos_count) * 100
+        metrics_dict['correct/all'] = (pos_correct + neg_correct) \
+            / np.float32(metrics_t.shape[1]) * 100
+        metrics_dict['correct/neg'] = neg_correct / np.float32(neg_count) * 100
+        metrics_dict['correct/pos'] = pos_correct / np.float32(pos_count) * 100
 
 
         log.info(
         log.info(
             ("E{} {:8} {loss/all:.4f} loss, "
             ("E{} {:8} {loss/all:.4f} loss, "

+ 34 - 12
p2ch12/dsets.py

@@ -30,12 +30,12 @@ raw_cache = getCache('part2ch12_raw')
 CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, diameter_mm, series_uid, center_xyz')
 CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, diameter_mm, series_uid, center_xyz')
 
 
 @functools.lru_cache(1)
 @functools.lru_cache(1)
-def getCandidateInfoList(requireDataOnDisk_bool=True):
+def getCandidateInfoList(requireOnDisk_bool=True):
     # We construct a set with all series_uids that are present on disk.
     # We construct a set with all series_uids that are present on disk.
     # This will let us use the data, even if we haven't downloaded all of
     # This will let us use the data, even if we haven't downloaded all of
     # the subsets yet.
     # the subsets yet.
     mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
     mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
-    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
 
 
     diameter_dict = {}
     diameter_dict = {}
     with open('data/part2/luna/annotations.csv', "r") as f:
     with open('data/part2/luna/annotations.csv', "r") as f:
@@ -44,21 +44,24 @@ def getCandidateInfoList(requireDataOnDisk_bool=True):
             annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
             annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
             annotationDiameter_mm = float(row[4])
             annotationDiameter_mm = float(row[4])
 
 
-            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
+            diameter_dict.setdefault(series_uid, []).append(
+                (annotationCenter_xyz, annotationDiameter_mm),
+            )
 
 
     candidateInfo_list = []
     candidateInfo_list = []
     with open('data/part2/luna/candidates.csv', "r") as f:
     with open('data/part2/luna/candidates.csv', "r") as f:
         for row in list(csv.reader(f))[1:]:
         for row in list(csv.reader(f))[1:]:
             series_uid = row[0]
             series_uid = row[0]
 
 
-            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                 continue
                 continue
 
 
             isNodule_bool = bool(int(row[4]))
             isNodule_bool = bool(int(row[4]))
             candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
             candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
 
 
             candidateDiameter_mm = 0.0
             candidateDiameter_mm = 0.0
-            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
+            for annotation_tup in diameter_dict.get(series_uid, []):
+                annotationCenter_xyz, annotationDiameter_mm = annotation_tup
                 for i in range(3):
                 for i in range(3):
                     delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                     delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                     if delta_mm > annotationDiameter_mm / 4:
                     if delta_mm > annotationDiameter_mm / 4:
@@ -67,14 +70,21 @@ def getCandidateInfoList(requireDataOnDisk_bool=True):
                     candidateDiameter_mm = annotationDiameter_mm
                     candidateDiameter_mm = annotationDiameter_mm
                     break
                     break
 
 
-            candidateInfo_list.append(CandidateInfoTuple(isNodule_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+            candidateInfo_list.append(CandidateInfoTuple(
+                isNodule_bool,
+                candidateDiameter_mm,
+                series_uid,
+                candidateCenter_xyz,
+            ))
 
 
     candidateInfo_list.sort(reverse=True)
     candidateInfo_list.sort(reverse=True)
     return candidateInfo_list
     return candidateInfo_list
 
 
 class Ct:
 class Ct:
     def __init__(self, series_uid):
     def __init__(self, series_uid):
-        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
+        mhd_path = glob.glob(
+            'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
+        )[0]
 
 
         ct_mhd = sitk.ReadImage(mhd_path)
         ct_mhd = sitk.ReadImage(mhd_path)
         ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
         ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
@@ -93,7 +103,12 @@ class Ct:
         self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
         self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
 
 
     def getRawCandidate(self, center_xyz, width_irc):
     def getRawCandidate(self, center_xyz, width_irc):
-        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_a)
+        center_irc = xyz2irc(
+            center_xyz,
+            self.origin_xyz,
+            self.vxSize_xyz,
+            self.direction_a,
+        )
 
 
         slice_list = []
         slice_list = []
         for axis, center_val in enumerate(center_irc):
         for axis, center_val in enumerate(center_irc):
@@ -136,7 +151,8 @@ def getCtAugmentedCandidate(
         series_uid, center_xyz, width_irc,
         series_uid, center_xyz, width_irc,
         use_cache=True):
         use_cache=True):
     if use_cache:
     if use_cache:
-        ct_chunk, center_irc = getCtRawCandidate(series_uid, center_xyz, width_irc)
+        ct_chunk, center_irc = \
+            getCtRawCandidate(series_uid, center_xyz, width_irc)
     else:
     else:
         ct = getCt(series_uid)
         ct = getCt(series_uid)
         ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
         ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
@@ -219,7 +235,9 @@ class LunaDataset(Dataset):
             self.use_cache = True
             self.use_cache = True
 
 
         if series_uid:
         if series_uid:
-            self.candidateInfo_list = [x for x in self.candidateInfo_list if x.series_uid == series_uid]
+            self.candidateInfo_list = [
+                x for x in self.candidateInfo_list if x.series_uid == series_uid
+            ]
 
 
         if isValSet_bool:
         if isValSet_bool:
             assert val_stride > 0, val_stride
             assert val_stride > 0, val_stride
@@ -238,8 +256,12 @@ class LunaDataset(Dataset):
         else:
         else:
             raise Exception("Unknown sort: " + repr(sortby_str))
             raise Exception("Unknown sort: " + repr(sortby_str))
 
 
-        self.negative_list = [nt for nt in self.candidateInfo_list if not nt.isNodule_bool]
-        self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
+        self.negative_list = [
+            nt for nt in self.candidateInfo_list if not nt.isNodule_bool
+        ]
+        self.pos_list = [
+            nt for nt in self.candidateInfo_list if nt.isNodule_bool
+        ]
 
 
         log.info("{!r}: {} {} samples, {} neg, {} pos, {} ratio".format(
         log.info("{!r}: {} {} samples, {} neg, {} pos, {} ratio".format(
             self,
             self,

+ 12 - 4
p2ch12/model.py

@@ -36,13 +36,17 @@ class LunaModel(nn.Module):
                 nn.ConvTranspose2d,
                 nn.ConvTranspose2d,
                 nn.ConvTranspose3d,
                 nn.ConvTranspose3d,
             }:
             }:
-                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
+                nn.init.kaiming_normal_(
+                    m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
+                )
                 if m.bias is not None:
                 if m.bias is not None:
-                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    fan_in, fan_out = \
+                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                     bound = 1 / math.sqrt(fan_out)
                     bound = 1 / math.sqrt(fan_out)
                     nn.init.normal_(m.bias, -bound, bound)
                     nn.init.normal_(m.bias, -bound, bound)
 
 
 
 
+
     def forward(self, input_batch):
     def forward(self, input_batch):
         bn_output = self.tail_batchnorm(input_batch)
         bn_output = self.tail_batchnorm(input_batch)
 
 
@@ -64,9 +68,13 @@ class LunaBlock(nn.Module):
     def __init__(self, in_channels, conv_channels):
     def __init__(self, in_channels, conv_channels):
         super().__init__()
         super().__init__()
 
 
-        self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.conv1 = nn.Conv3d(
+            in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
+        )
         self.relu1 = nn.ReLU(inplace=True)
         self.relu1 = nn.ReLU(inplace=True)
-        self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.conv2 = nn.Conv3d(
+            conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
+        )
         self.relu2 = nn.ReLU(inplace=True)
         self.relu2 = nn.ReLU(inplace=True)
 
 
         self.maxpool = nn.MaxPool3d(2, 2)
         self.maxpool = nn.MaxPool3d(2, 2)

+ 1 - 1
p2ch12/training.py

@@ -124,7 +124,7 @@ class LunaTrainingApp:
     def initModel(self):
     def initModel(self):
         model = LunaModel()
         model = LunaModel()
         if self.use_cuda:
         if self.use_cuda:
-            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
+            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
             if torch.cuda.device_count() > 1:
             if torch.cuda.device_count() > 1:
                 model = nn.DataParallel(model)
                 model = nn.DataParallel(model)
             model = model.to(self.device)
             model = model.to(self.device)

+ 1 - 1
p2ch12_explore_data.ipynb

@@ -34,7 +34,7 @@
     "from util.util import xyz2irc\n",
     "from util.util import xyz2irc\n",
     "\n",
     "\n",
     "\n",
     "\n",
-    "candidateInfo_list = getCandidateInfoList(requireDataOnDisk_bool=False)\n",
+    "candidateInfo_list = getCandidateInfoList(requireOnDisk_bool=False)\n",
     "candidateInfo_list[0]"
     "candidateInfo_list[0]"
    ]
    ]
   },
   },

+ 36 - 35
p2ch13/dsets.py

@@ -33,12 +33,12 @@ MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_
 CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')
 CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')
 
 
 @functools.lru_cache(1)
 @functools.lru_cache(1)
-def getCandidateInfoList(requireDataOnDisk_bool=True):
+def getCandidateInfoList(requireOnDisk_bool=True):
     # We construct a set with all series_uids that are present on disk.
     # We construct a set with all series_uids that are present on disk.
     # This will let us use the data, even if we haven't downloaded all of
     # This will let us use the data, even if we haven't downloaded all of
     # the subsets yet.
     # the subsets yet.
     mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
     mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
-    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
 
 
     candidateInfo_list = []
     candidateInfo_list = []
     with open('data/part2/luna/annotations_with_malignancy.csv', "r") as f:
     with open('data/part2/luna/annotations_with_malignancy.csv', "r") as f:
@@ -63,7 +63,7 @@ def getCandidateInfoList(requireDataOnDisk_bool=True):
         for row in list(csv.reader(f))[1:]:
         for row in list(csv.reader(f))[1:]:
             series_uid = row[0]
             series_uid = row[0]
 
 
-            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                 continue
                 continue
 
 
             isNodule_bool = bool(int(row[4]))
             isNodule_bool = bool(int(row[4]))
@@ -85,18 +85,21 @@ def getCandidateInfoList(requireDataOnDisk_bool=True):
     return candidateInfo_list
     return candidateInfo_list
 
 
 @functools.lru_cache(1)
 @functools.lru_cache(1)
-def getCandidateInfoDict(requireDataOnDisk_bool=True):
-    candidateInfo_list = getCandidateInfoList(requireDataOnDisk_bool)
+def getCandidateInfoDict(requireOnDisk_bool=True):
+    candidateInfo_list = getCandidateInfoList(requireOnDisk_bool)
     candidateInfo_dict = {}
     candidateInfo_dict = {}
 
 
     for candidateInfo_tup in candidateInfo_list:
     for candidateInfo_tup in candidateInfo_list:
-        candidateInfo_dict.setdefault(candidateInfo_tup.series_uid, []).append(candidateInfo_tup)
+        candidateInfo_dict.setdefault(candidateInfo_tup.series_uid,
+                                      []).append(candidateInfo_tup)
 
 
     return candidateInfo_dict
     return candidateInfo_dict
 
 
 class Ct:
 class Ct:
     def __init__(self, series_uid):
     def __init__(self, series_uid):
-        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
+        mhd_path = glob.glob(
+            'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
+        )[0]
 
 
         ct_mhd = sitk.ReadImage(mhd_path)
         ct_mhd = sitk.ReadImage(mhd_path)
         self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
         self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
@@ -117,13 +120,14 @@ class Ct:
             for candidate_tup in candidateInfo_list
             for candidate_tup in candidateInfo_list
             if candidate_tup.isNodule_bool
             if candidate_tup.isNodule_bool
         ]
         ]
-        self.positive_mask, _ = self.buildAnnotationMask(self.positiveInfo_list)
-        self.positive_indexes = sorted(set(self.positive_mask.sum(axis=(1,2)).nonzero()[0]))
+        self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
+        self.positive_indexes = (self.positive_mask.sum(axis=(1,2))
+                                 .nonzero()[0].tolist())
 
 
-    def buildAnnotationMask(self, candidateInfo_list, threshold_hu = -700):
+    def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
         boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool)
         boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool)
 
 
-        for candidateInfo_tup in candidateInfo_list:
+        for candidateInfo_tup in positiveInfo_list:
             center_irc = xyz2irc(
             center_irc = xyz2irc(
                 candidateInfo_tup.center_xyz,
                 candidateInfo_tup.center_xyz,
                 self.origin_xyz,
                 self.origin_xyz,
@@ -162,19 +166,18 @@ class Ct:
             # assert row_radius > 0
             # assert row_radius > 0
             # assert col_radius > 0
             # assert col_radius > 0
 
 
-            slice_tup = (
-                slice(ci - index_radius, ci + index_radius + 1),
-                slice(cr - row_radius, cr + row_radius + 1),
-                slice(cc - col_radius, cc + row_radius + 1),
-            )
-            boundingBox_a[slice_tup] = True
+            boundingBox_a[
+                 ci - index_radius: ci + index_radius + 1,
+                 cr - row_radius: cr + row_radius + 1,
+                 cc - col_radius: cc + col_radius + 1] = True
 
 
         mask_a = boundingBox_a & (self.hu_a > threshold_hu)
         mask_a = boundingBox_a & (self.hu_a > threshold_hu)
 
 
-        return mask_a, boundingBox_a
+        return mask_a
 
 
     def getRawCandidate(self, center_xyz, width_irc):
     def getRawCandidate(self, center_xyz, width_irc):
-        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_a)
+        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
+                             self.direction_a)
 
 
         slice_list = []
         slice_list = []
         for axis, center_val in enumerate(center_irc):
         for axis, center_val in enumerate(center_irc):
@@ -209,7 +212,8 @@ def getCt(series_uid):
 @raw_cache.memoize(typed=True)
 @raw_cache.memoize(typed=True)
 def getCtRawCandidate(series_uid, center_xyz, width_irc):
 def getCtRawCandidate(series_uid, center_xyz, width_irc):
     ct = getCt(series_uid)
     ct = getCt(series_uid)
-    ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
+    ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz,
+                                                         width_irc)
     ct_chunk.clip(-1000, 1000, ct_chunk)
     ct_chunk.clip(-1000, 1000, ct_chunk)
     return ct_chunk, pos_chunk, center_irc
     return ct_chunk, pos_chunk, center_irc
 
 
@@ -248,24 +252,20 @@ class Luna2dSegmentationDataset(Dataset):
             index_count, positive_indexes = getCtSampleSize(series_uid)
             index_count, positive_indexes = getCtSampleSize(series_uid)
 
 
             if self.fullCt_bool:
             if self.fullCt_bool:
-                self.sample_list.extend([
-                    (series_uid, slice_ndx)
-                    for slice_ndx
-                    in range(index_count)
-                ])
+                self.sample_list += [(series_uid, slice_ndx)
+                                     for slice_ndx in range(index_count)]
             else:
             else:
-                self.sample_list.extend([
-                    (series_uid, slice_ndx)
-                    for slice_ndx
-                    in positive_indexes
-                ])
+                self.sample_list += [(series_uid, slice_ndx)
+                                     for slice_ndx in positive_indexes]
 
 
         self.candidateInfo_list = getCandidateInfoList()
         self.candidateInfo_list = getCandidateInfoList()
 
 
         series_set = set(self.series_list)
         series_set = set(self.series_list)
-        self.candidateInfo_list = [cit for cit in self.candidateInfo_list if cit.series_uid in series_set]
+        self.candidateInfo_list = [cit for cit in self.candidateInfo_list
+                                   if cit.series_uid in series_set]
 
 
-        self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
+        self.pos_list = [nt for nt in self.candidateInfo_list
+                            if nt.isNodule_bool]
 
 
         log.info("{!r}: {} {} series, {} slices, {} nodules".format(
         log.info("{!r}: {} {} series, {} slices, {} nodules".format(
             self,
             self,
@@ -291,7 +291,6 @@ class Luna2dSegmentationDataset(Dataset):
         for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
         for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
             context_ndx = max(context_ndx, 0)
             context_ndx = max(context_ndx, 0)
             context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
             context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
-
             ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
             ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
 
 
         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
         # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
@@ -332,8 +331,10 @@ class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
 
 
         row_offset = random.randrange(0,32)
         row_offset = random.randrange(0,32)
         col_offset = random.randrange(0,32)
         col_offset = random.randrange(0,32)
-        ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset+64, col_offset:col_offset+64]).to(torch.float32)
-        pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset+64, col_offset:col_offset+64]).to(torch.long)
+        ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset+64,
+                                     col_offset:col_offset+64]).to(torch.float32)
+        pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset+64,
+                                       col_offset:col_offset+64]).to(torch.long)
 
 
         slice_ndx = center_irc.index
         slice_ndx = center_irc.index
 
 

+ 19 - 28
p2ch13/model.py

@@ -34,9 +34,12 @@ class UNetWrapper(nn.Module):
         }
         }
         for m in self.modules():
         for m in self.modules():
             if type(m) in init_set:
             if type(m) in init_set:
-                nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu', a=0)
+                nn.init.kaiming_normal_(
+                    m.weight.data, mode='fan_out', nonlinearity='relu', a=0
+                )
                 if m.bias is not None:
                 if m.bias is not None:
-                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    fan_in, fan_out = \
+                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                     bound = 1 / math.sqrt(fan_out)
                     bound = 1 / math.sqrt(fan_out)
                     nn.init.normal_(m.bias, -bound, bound)
                     nn.init.normal_(m.bias, -bound, bound)
 
 
@@ -48,11 +51,12 @@ class UNetWrapper(nn.Module):
         bn_output = self.input_batchnorm(input_batch)
         bn_output = self.input_batchnorm(input_batch)
         un_output = self.unet(bn_output)
         un_output = self.unet(bn_output)
         fn_output = self.final(un_output)
         fn_output = self.final(un_output)
-
         return fn_output
         return fn_output
 
 
 class SegmentationAugmentation(nn.Module):
 class SegmentationAugmentation(nn.Module):
-    def __init__(self, flip=None, offset=None, scale=None, rotate=None, noise=None):
+    def __init__(
+            self, flip=None, offset=None, scale=None, rotate=None, noise=None
+    ):
         super().__init__()
         super().__init__()
 
 
         self.flip = flip
         self.flip = flip
@@ -63,29 +67,17 @@ class SegmentationAugmentation(nn.Module):
 
 
     def forward(self, input_g, label_g):
     def forward(self, input_g, label_g):
         transform_t = self._build2dTransformMatrix()
         transform_t = self._build2dTransformMatrix()
-
-        # log.debug([input_g.shape, label_g.shape])
-
         transform_t = transform_t.expand(input_g.shape[0], -1, -1)
         transform_t = transform_t.expand(input_g.shape[0], -1, -1)
         transform_t = transform_t.to(input_g.device, torch.float32)
         transform_t = transform_t.to(input_g.device, torch.float32)
-        affine_t = F.affine_grid(
-                transform_t[:,:2],
-                input_g.size(),
-                align_corners=False,
-            )
-
-        augmented_input_g = F.grid_sample(
-                input_g,
-                affine_t,
-                padding_mode='border',
-                align_corners=False,
-            )
-        augmented_label_g = F.grid_sample(
-                label_g.to(torch.float32),
-                affine_t,
-                padding_mode='border',
-                align_corners=False,
-            )
+        affine_t = F.affine_grid(transform_t[:,:2],
+                input_g.size(), align_corners=False)
+
+        augmented_input_g = F.grid_sample(input_g,
+                affine_t, padding_mode='border',
+                align_corners=False)
+        augmented_label_g = F.grid_sample(label_g.to(torch.float32),
+                affine_t, padding_mode='border',
+                align_corners=False)
 
 
         if self.noise:
         if self.noise:
             noise_t = torch.randn_like(augmented_input_g)
             noise_t = torch.randn_like(augmented_input_g)
@@ -96,7 +88,7 @@ class SegmentationAugmentation(nn.Module):
         return augmented_input_g, augmented_label_g > 0.5
         return augmented_input_g, augmented_label_g > 0.5
 
 
     def _build2dTransformMatrix(self):
     def _build2dTransformMatrix(self):
-        transform_t = torch.eye(3).to(torch.float64)
+        transform_t = torch.eye(3)
 
 
         for i in range(2):
         for i in range(2):
             if self.flip:
             if self.flip:
@@ -121,8 +113,7 @@ class SegmentationAugmentation(nn.Module):
             rotation_t = torch.tensor([
             rotation_t = torch.tensor([
                 [c, -s, 0],
                 [c, -s, 0],
                 [s, c, 0],
                 [s, c, 0],
-                [0, 0, 1],
-            ], dtype=torch.float64)
+                [0, 0, 1]])
 
 
             transform_t @= rotation_t
             transform_t @= rotation_t
 
 

+ 23 - 46
p2ch13/training.py

@@ -144,7 +144,7 @@ class SegmentationTrainingApp:
         augmentation_model = SegmentationAugmentation(**self.augmentation_dict)
         augmentation_model = SegmentationAugmentation(**self.augmentation_dict)
 
 
         if self.use_cuda:
         if self.use_cuda:
-            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
+            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
             if torch.cuda.device_count() > 1:
             if torch.cuda.device_count() > 1:
                 segmentation_model = nn.DataParallel(segmentation_model)
                 segmentation_model = nn.DataParallel(segmentation_model)
                 augmentation_model = nn.DataParallel(augmentation_model)
                 augmentation_model = nn.DataParallel(augmentation_model)
@@ -229,6 +229,7 @@ class SegmentationTrainingApp:
             self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
             self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)
 
 
             if epoch_ndx == 1 or epoch_ndx % self.validation_cadence == 0:
             if epoch_ndx == 1 or epoch_ndx % self.validation_cadence == 0:
+                # if validation is wanted
                 valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                 valMetrics_t = self.doValidation(epoch_ndx, val_dl)
                 score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
                 score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
                 best_score = max(score, best_score)
                 best_score = max(score, best_score)
@@ -278,7 +279,8 @@ class SegmentationTrainingApp:
 
 
         return valMetrics_g.to('cpu')
         return valMetrics_g.to('cpu')
 
 
-    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g, classificationThreshold=0.5):
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g,
+                         classificationThreshold=0.5):
         input_t, label_t, series_list, _slice_ndx_list = batch_tup
         input_t, label_t, series_list, _slice_ndx_list = batch_tup
 
 
         input_g = input_t.to(self.device, non_blocking=True)
         input_g = input_t.to(self.device, non_blocking=True)
@@ -296,30 +298,21 @@ class SegmentationTrainingApp:
         end_ndx = start_ndx + input_t.size(0)
         end_ndx = start_ndx + input_t.size(0)
 
 
         with torch.no_grad():
         with torch.no_grad():
-            predictionBool_g = \
-                (prediction_g[:, 0:1] > classificationThreshold).to(torch.float32)
+            predictionBool_g = (prediction_g[:, 0:1]
+                                > classificationThreshold).to(torch.float32)
 
 
-            # metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_list
-            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
-            # metrics_g[METRICS_FN_LOSS_NDX, start_ndx:end_ndx] = fnLoss_g
-
-            intersectionSum = lambda a, b: (a * b).sum(dim=[1,2,3])
-
-            tp = intersectionSum(    predictionBool_g,  label_g)
-            fn = intersectionSum(1 - predictionBool_g,  label_g)
-            fp = intersectionSum(    predictionBool_g, ~label_g)
+            tp = (     predictionBool_g *  label_g).sum(dim=[1,2,3])
+            fn = ((1 - predictionBool_g) *  label_g).sum(dim=[1,2,3])
+            fp = (     predictionBool_g * (~label_g)).sum(dim=[1,2,3])
 
 
+            metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_g
             metrics_g[METRICS_TP_NDX, start_ndx:end_ndx] = tp
             metrics_g[METRICS_TP_NDX, start_ndx:end_ndx] = tp
             metrics_g[METRICS_FN_NDX, start_ndx:end_ndx] = fn
             metrics_g[METRICS_FN_NDX, start_ndx:end_ndx] = fn
             metrics_g[METRICS_FP_NDX, start_ndx:end_ndx] = fp
             metrics_g[METRICS_FP_NDX, start_ndx:end_ndx] = fp
 
 
-            del tp, fn, fp
-
-        return diceLoss_g.mean() + fnLoss_g.mean() * 2**3# / 2**1
-
-    def diceLoss(self, prediction_g, label_g, epsilon=1, p=False):
-        # log.debug([prediction_g.shape, label_g.shape])
+        return diceLoss_g.mean() + fnLoss_g.mean() * 8
 
 
+    def diceLoss(self, prediction_g, label_g, epsilon=1):
         diceLabel_g = label_g.sum(dim=[1,2,3])
         diceLabel_g = label_g.sum(dim=[1,2,3])
         dicePrediction_g = prediction_g.sum(dim=[1,2,3])
         dicePrediction_g = prediction_g.sum(dim=[1,2,3])
         diceCorrect_g = (prediction_g * label_g).sum(dim=[1,2,3])
         diceCorrect_g = (prediction_g * label_g).sum(dim=[1,2,3])
@@ -333,13 +326,12 @@ class SegmentationTrainingApp:
     def logImages(self, epoch_ndx, mode_str, dl):
     def logImages(self, epoch_ndx, mode_str, dl):
         self.segmentation_model.eval()
         self.segmentation_model.eval()
 
 
-        images_iter = sorted(dl.dataset.series_list)[:12]
-        for series_ndx, series_uid in enumerate(images_iter):
+        images = sorted(dl.dataset.series_list)[:12]
+        for series_ndx, series_uid in enumerate(images):
             ct = getCt(series_uid)
             ct = getCt(series_uid)
 
 
             for slice_ndx in range(6):
             for slice_ndx in range(6):
-                ct_ndx = slice_ndx * ct.hu_a.shape[0] // 5
-                ct_ndx = min(ct_ndx, ct.hu_a.shape[0] - 1)
+                ct_ndx = slice_ndx * (ct.hu_a.shape[0] - 1) // 5
                 sample_tup = dl.dataset.getitem_fullSlice(series_uid, ct_ndx)
                 sample_tup = dl.dataset.getitem_fullSlice(series_uid, ct_ndx)
 
 
                 ct_t, label_t, series_uid, ct_ndx = sample_tup
                 ct_t, label_t, series_uid, ct_ndx = sample_tup
@@ -351,9 +343,8 @@ class SegmentationTrainingApp:
                 prediction_a = prediction_g.to('cpu').detach().numpy()[0] > 0.5
                 prediction_a = prediction_g.to('cpu').detach().numpy()[0] > 0.5
                 label_a = label_g.cpu().numpy()[0][0] > 0.5
                 label_a = label_g.cpu().numpy()[0][0] > 0.5
 
 
-                ct_t[:-1,:,:] /= 1000
-                ct_t[:-1,:,:] += 1
-                ct_t[:-1,:,:] /= 2
+                ct_t[:-1,:,:] /= 2000
+                ct_t[:-1,:,:] += 0.5
 
 
                 ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()
                 ctSlice_a = ct_t[dl.dataset.contextSlices_count].numpy()
 
 
@@ -369,11 +360,7 @@ class SegmentationTrainingApp:
 
 
                 writer = getattr(self, mode_str + '_writer')
                 writer = getattr(self, mode_str + '_writer')
                 writer.add_image(
                 writer.add_image(
-                    '{}/{}_prediction_{}'.format(
-                        mode_str,
-                        series_ndx,
-                        slice_ndx,
-                    ),
+                    f'{mode_str}/{series_ndx}_prediction_{slice_ndx}',
                     image_a,
                     image_a,
                     self.totalTrainingSamples_count,
                     self.totalTrainingSamples_count,
                     dataformats='HWC',
                     dataformats='HWC',
@@ -399,13 +386,11 @@ class SegmentationTrainingApp:
                         self.totalTrainingSamples_count,
                         self.totalTrainingSamples_count,
                         dataformats='HWC',
                         dataformats='HWC',
                     )
                     )
+                # This flush prevents TB from getting confused about which
+                # data item belongs where.
                 writer.flush()
                 writer.flush()
 
 
-    def logMetrics(self,
-        epoch_ndx,
-        mode_str,
-        metrics_t,
-    ):
+    def logMetrics(self, epoch_ndx, mode_str, metrics_t):
         log.info("E{} {}".format(
         log.info("E{} {}".format(
             epoch_ndx,
             epoch_ndx,
             type(self).__name__,
             type(self).__name__,
@@ -528,17 +513,9 @@ class SegmentationTrainingApp:
 
 
         if isBest:
         if isBest:
             best_path = os.path.join(
             best_path = os.path.join(
-                'data-unversioned',
-                'part2',
-                'models',
+                'data-unversioned', 'part2', 'models',
                 self.cli_args.tb_prefix,
                 self.cli_args.tb_prefix,
-                '{}_{}_{}.{}.state'.format(
-                    type_str,
-                    self.time_str,
-                    self.cli_args.comment,
-                    'best',
-                )
-            )
+                f'{type_str}_{self.time_str}_{self.cli_args.comment}.best.state')
             shutil.copyfile(file_path, best_path)
             shutil.copyfile(file_path, best_path)
 
 
             log.info("Saved model params to {}".format(best_path))
             log.info("Saved model params to {}".format(best_path))

+ 1 - 1
p2ch13_explore_data.ipynb

@@ -36,7 +36,7 @@
     "from util.util import xyz2irc\n",
     "from util.util import xyz2irc\n",
     "\n",
     "\n",
     "\n",
     "\n",
-    "candidateInfo_list = getCandidateInfoList(requireDataOnDisk_bool=False)\n",
+    "candidateInfo_list = getCandidateInfoList(requireOnDisk_bool=False)\n",
     "candidateInfo_list[0]"
     "candidateInfo_list[0]"
    ]
    ]
   },
   },

+ 1 - 1
p2ch13_explore_data_v2.ipynb

@@ -36,7 +36,7 @@
     "from util.util import xyz2irc\n",
     "from util.util import xyz2irc\n",
     "\n",
     "\n",
     "\n",
     "\n",
-    "candidateInfo_list = getCandidateInfoList(requireDataOnDisk_bool=False)\n",
+    "candidateInfo_list = getCandidateInfoList(requireOnDisk_bool=False)\n",
     "candidateInfo_list[0]"
     "candidateInfo_list[0]"
    ]
    ]
   },
   },

+ 0 - 1
p2ch14/check_nodule_fp_rate.py

@@ -418,7 +418,6 @@ class FalsePosRateCheckApp:
             index=list(range(1, candidate_count+1)),
             index=list(range(1, candidate_count+1)),
         )
         )
 
 
-
         candidateInfo_list = []
         candidateInfo_list = []
         for i, center_irc in enumerate(centerIrc_list):
         for i, center_irc in enumerate(centerIrc_list):
             assert np.isfinite(center_irc).all(), repr([series_uid, i, candidate_count, (ct.hu_a[candidateLabel_a == i+1]).sum(), center_irc])
             assert np.isfinite(center_irc).all(), repr([series_uid, i, candidate_count, (ct.hu_a[candidateLabel_a == i+1]).sum(), center_irc])

+ 38 - 52
p2ch14/dsets.py

@@ -27,16 +27,22 @@ log.setLevel(logging.DEBUG)
 
 
 raw_cache = getCache('part2ch14_raw')
 raw_cache = getCache('part2ch14_raw')
 
 
-CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')
-MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
+CandidateInfoTuple = namedtuple(
+    'CandidateInfoTuple',
+    'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz',
+)
+MaskTuple = namedtuple(
+    'MaskTuple',
+    'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask',
+)
 
 
 @functools.lru_cache(1)
 @functools.lru_cache(1)
-def getCandidateInfoList(requireDataOnDisk_bool=True):
+def getCandidateInfoList(requireOnDisk_bool=True):
     # We construct a set with all series_uids that are present on disk.
     # We construct a set with all series_uids that are present on disk.
     # This will let us use the data, even if we haven't downloaded all of
     # This will let us use the data, even if we haven't downloaded all of
     # the subsets yet.
     # the subsets yet.
     mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
     mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
-    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
 
 
     candidateInfo_list = []
     candidateInfo_list = []
     with open('data/part2/luna/annotations_with_malignancy.csv', "r") as f:
     with open('data/part2/luna/annotations_with_malignancy.csv', "r") as f:
@@ -52,21 +58,28 @@ def getCandidateInfoList(requireDataOnDisk_bool=True):
         for row in list(csv.reader(f))[1:]:
         for row in list(csv.reader(f))[1:]:
             series_uid = row[0]
             series_uid = row[0]
 
 
-            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                 continue
                 continue
 
 
             isNodule_bool = bool(int(row[4]))
             isNodule_bool = bool(int(row[4]))
             candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
             candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
 
 
             if not isNodule_bool:
             if not isNodule_bool:
-                candidateInfo_list.append(CandidateInfoTuple(False, False, False, 0.0, series_uid, candidateCenter_xyz))
+                candidateInfo_list.append(CandidateInfoTuple(
+                    False,
+                    False,
+                    False,
+                    0.0,
+                    series_uid,
+                    candidateCenter_xyz,
+                ))
 
 
     candidateInfo_list.sort(reverse=True)
     candidateInfo_list.sort(reverse=True)
     return candidateInfo_list
     return candidateInfo_list
 
 
 @functools.lru_cache(1)
 @functools.lru_cache(1)
-def getCandidateInfoDict(requireDataOnDisk_bool=True):
-    candidateInfo_list = getCandidateInfoList(requireDataOnDisk_bool)
+def getCandidateInfoDict(requireOnDisk_bool=True):
+    candidateInfo_list = getCandidateInfoList(requireOnDisk_bool)
     candidateInfo_dict = {}
     candidateInfo_dict = {}
 
 
     for candidateInfo_tup in candidateInfo_list:
     for candidateInfo_tup in candidateInfo_list:
@@ -77,7 +90,9 @@ def getCandidateInfoDict(requireDataOnDisk_bool=True):
 
 
 class Ct:
 class Ct:
     def __init__(self, series_uid):
     def __init__(self, series_uid):
-        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
+        mhd_path = glob.glob(
+            'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
+        )[0]
 
 
         ct_mhd = sitk.ReadImage(mhd_path)
         ct_mhd = sitk.ReadImage(mhd_path)
         ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
         ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
@@ -251,10 +266,14 @@ class LunaDataset(Dataset):
         else:
         else:
             raise Exception("Unknown sort: " + repr(sortby_str))
             raise Exception("Unknown sort: " + repr(sortby_str))
 
 
-        self.neg_list = [nt for nt in self.candidateInfo_list if not nt.isNodule_bool]
-        self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
-        self.ben_list = [nt for nt in self.pos_list if not nt.isMal_bool]
-        self.mal_list = [nt for nt in self.pos_list if nt.isMal_bool]
+        self.neg_list = \
+            [nt for nt in self.candidateInfo_list if not nt.isNodule_bool]
+        self.pos_list = \
+            [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
+        self.ben_list = \
+            [nt for nt in self.pos_list if not nt.isMal_bool]
+        self.mal_list = \
+            [nt for nt in self.pos_list if nt.isMal_bool]
 
 
         log.info("{!r}: {} {} samples, {} neg, {} pos, {} ratio".format(
         log.info("{!r}: {} {} samples, {} neg, {} pos, {} ratio".format(
             self,
             self,
@@ -276,8 +295,6 @@ class LunaDataset(Dataset):
     def __len__(self):
     def __len__(self):
         if self.ratio_int:
         if self.ratio_int:
             return 50000
             return 50000
-            # return 50000
-            # return 200000
         else:
         else:
             return len(self.candidateInfo_list)
             return len(self.candidateInfo_list)
 
 
@@ -295,7 +312,9 @@ class LunaDataset(Dataset):
         else:
         else:
             candidateInfo_tup = self.candidateInfo_list[ndx]
             candidateInfo_tup = self.candidateInfo_list[ndx]
 
 
-        return self.sampleFromCandidateInfo_tup(candidateInfo_tup, candidateInfo_tup.isNodule_bool)
+        return self.sampleFromCandidateInfo_tup(
+            candidateInfo_tup, candidateInfo_tup.isNodule_bool
+        )
 
 
     def sampleFromCandidateInfo_tup(self, candidateInfo_tup, label_bool):
     def sampleFromCandidateInfo_tup(self, candidateInfo_tup, label_bool):
         width_irc = (32, 48, 48)
         width_irc = (32, 48, 48)
@@ -337,45 +356,10 @@ class LunaDataset(Dataset):
         return candidate_t, label_t, index_t, candidateInfo_tup.series_uid, torch.tensor(center_irc)
         return candidate_t, label_t, index_t, candidateInfo_tup.series_uid, torch.tensor(center_irc)
 
 
 
 
-# class MalignantLunaDataset(LunaDataset):
-# # tag::ds_balancing_len[]
-#     def __len__(self):
-#         if self.ratio_int:
-#             return 10000
-#             # return 50000
-#             # return 200000
-#         else:
-#             return len(self.ben_list + self.mal_list)
-# # end::ds_balancing_len[]
-#
-# # tag::ds_balancing_getitem[]
-#     def __getitem__(self, ndx):
-#         if self.ratio_int:
-#             mal_ndx = ndx // (self.ratio_int + 1)
-#
-#             if ndx % (self.ratio_int + 1):
-#                 ben_ndx = ndx - 1 - mal_ndx
-#                 ben_ndx %= len(self.ben_list)
-#                 candidateInfo_tup = self.ben_list[ben_ndx]
-#             else:
-#                 mal_ndx %= len(self.mal_list)
-#                 candidateInfo_tup = self.mal_list[mal_ndx]
-#         else:
-#             if ndx >= len(self.ben_list):
-#                 candidateInfo_tup = self.mal_list[ndx - len(self.ben_list)]
-#             else:
-#                 candidateInfo_tup = self.ben_list[ndx]
-#
-#         return self.sampleFromCandidateInfo_tup(candidateInfo_tup, candidateInfo_tup.isMal_bool)
-# # end::ds_balancing_getitem[]
-
 class MalignantLunaDataset(LunaDataset):
 class MalignantLunaDataset(LunaDataset):
     def __len__(self):
     def __len__(self):
         if self.ratio_int:
         if self.ratio_int:
-            # return 10000
             return 100000
             return 100000
-            # return 50000
-            # return 200000
         else:
         else:
             return len(self.ben_list + self.mal_list)
             return len(self.ben_list + self.mal_list)
 
 
@@ -393,4 +377,6 @@ class MalignantLunaDataset(LunaDataset):
             else:
             else:
                 candidateInfo_tup = self.ben_list[ndx]
                 candidateInfo_tup = self.ben_list[ndx]
 
 
-        return self.sampleFromCandidateInfo_tup(candidateInfo_tup, candidateInfo_tup.isMal_bool)
+        return self.sampleFromCandidateInfo_tup(
+            candidateInfo_tup, candidateInfo_tup.isMal_bool
+        )

+ 11 - 47
p2ch14/model.py

@@ -84,9 +84,12 @@ class LunaModel(nn.Module):
                 nn.ConvTranspose2d,
                 nn.ConvTranspose2d,
                 nn.ConvTranspose3d,
                 nn.ConvTranspose3d,
             }:
             }:
-                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out', nonlinearity='relu')
+                nn.init.kaiming_normal_(
+                    m.weight.data, a=0, mode='fan_out', nonlinearity='relu'
+                )
                 if m.bias is not None:
                 if m.bias is not None:
-                    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight.data)
+                    fan_in, fan_out = \
+                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                     bound = 1 / math.sqrt(fan_out)
                     bound = 1 / math.sqrt(fan_out)
                     nn.init.normal_(m.bias, -bound, bound)
                     nn.init.normal_(m.bias, -bound, bound)
 
 
@@ -107,57 +110,18 @@ class LunaModel(nn.Module):
 
 
         return linear_output, self.head_activation(linear_output)
         return linear_output, self.head_activation(linear_output)
 
 
-class ModifiedLunaModel(nn.Sequential):
-    def __init__(self, in_channels=1, conv_channels=32):
-        super().__init__(
-            nn.BatchNorm3d(1),
-            nn.Conv3d(in_channels, conv_channels, (1, 5, 5), padding=(0, 2, 2)),
-            nn.ReLU(),
-            nn.MaxPool3d(2),
-            nn.Conv3d(conv_channels, 2 * conv_channels, (1, 5, 5), padding=(0, 2, 2)),
-            nn.ReLU(),
-            nn.BatchNorm3d(2 * conv_channels),
-            nn.MaxPool3d(2),
-            nn.Conv3d(2 * conv_channels, 4 * conv_channels, (1, 3, 3), padding=(0, 1, 1)),
-            nn.ReLU(),
-            nn.MaxPool3d(2),
-            nn.Conv3d(4 * conv_channels, 8 * conv_channels, (1, 3, 3), padding=(0, 1, 1)),
-            nn.ReLU(),
-            nn.MaxPool3d(2),
-            nn.Conv3d(8 * conv_channels, 16 * conv_channels, (1, 3, 3), padding=(0, 1, 1)),
-            nn.ReLU(),
-            nn.Flatten(),
-            nn.Linear(18 * 16 * conv_channels, 512),
-            nn.ReLU(),
-            nn.Linear(512, 256),
-            nn.ReLU(),
-            nn.Linear(256, 2)
-        )
-        self._init_weights()
-
-    def forward(self, x):
-        x = super().forward(x)
-        return x, nn.functional.softmax(x, 1)
-
-    def _init_weights(self):
-        for m in self.modules():
-            if isinstance(m, nn.Conv3d):
-                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
-                if m.bias is not None:
-                    nn.init.zeros_(m.bias)
-            elif isinstance(m, nn.Linear):
-                nn.init.kaiming_normal_(m.weight)
-                if m.bias is not None:
-                    nn.init.zeros_(m.bias)
-
 
 
 class LunaBlock(nn.Module):
 class LunaBlock(nn.Module):
     def __init__(self, in_channels, conv_channels):
     def __init__(self, in_channels, conv_channels):
         super().__init__()
         super().__init__()
 
 
-        self.conv1 = nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.conv1 = nn.Conv3d(
+            in_channels, conv_channels, kernel_size=3, padding=1, bias=True
+        )
         self.relu1 = nn.ReLU(inplace=True)
         self.relu1 = nn.ReLU(inplace=True)
-        self.conv2 = nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True)
+        self.conv2 = nn.Conv3d(
+            conv_channels, conv_channels, kernel_size=3, padding=1, bias=True
+        )
         self.relu2 = nn.ReLU(inplace=True)
         self.relu2 = nn.ReLU(inplace=True)
 
 
         self.maxpool = nn.MaxPool3d(2, 2)
         self.maxpool = nn.MaxPool3d(2, 2)

+ 31 - 17
p2ch14/nodule_analysis.py

@@ -211,7 +211,16 @@ class NoduleAnalysisApp:
         log.debug(self.cli_args.segmentation_path)
         log.debug(self.cli_args.segmentation_path)
         seg_dict = torch.load(self.cli_args.segmentation_path)
         seg_dict = torch.load(self.cli_args.segmentation_path)
 
 
-        seg_model = UNetWrapper(in_channels=7, n_classes=1, depth=3, wf=4, padding=True, batch_norm=True, up_mode='upconv')
+        seg_model = UNetWrapper(
+            in_channels=7,
+            n_classes=1,
+            depth=3,
+            wf=4,
+            padding=True,
+            batch_norm=True,
+            up_mode='upconv',
+        )
+
         seg_model.load_state_dict(seg_dict['model_state'])
         seg_model.load_state_dict(seg_dict['model_state'])
         seg_model.eval()
         seg_model.eval()
 
 
@@ -299,7 +308,10 @@ class NoduleAnalysisApp:
                 for candidateInfo_tup in getCandidateInfoList()
                 for candidateInfo_tup in getCandidateInfoList()
             )
             )
 
 
-        train_list = sorted(series_set - val_set) if self.cli_args.include_train else []
+        if self.cli_args.include_train:
+            train_list = sorted(series_set - val_set)
+        else:
+            train_list = []
         val_list = sorted(series_set & val_set)
         val_list = sorted(series_set & val_set)
 
 
 
 
@@ -314,9 +326,9 @@ class NoduleAnalysisApp:
             mask_a = self.segmentCt(ct, series_uid)
             mask_a = self.segmentCt(ct, series_uid)
 
 
             candidateInfo_list = self.groupSegmentationOutput(
             candidateInfo_list = self.groupSegmentationOutput(
-                series_uid, ct, mask_a
-            )
-            classifications_list = self.classifyCandidates(ct, candidateInfo_list)
+                series_uid, ct, mask_a)
+            classifications_list = self.classifyCandidates(
+                ct, candidateInfo_list)
 
 
             if not self.cli_args.run_validation:
             if not self.cli_args.run_validation:
                 print(f"found nodule candidates in {series_uid}:")
                 print(f"found nodule candidates in {series_uid}:")
@@ -329,11 +341,17 @@ class NoduleAnalysisApp:
                         print(s)
                         print(s)
 
 
             if series_uid in candidateInfo_dict:
             if series_uid in candidateInfo_dict:
-                one_confusion = match_and_score(classifications_list, candidateInfo_dict[series_uid])
+                one_confusion = match_and_score(
+                    classifications_list, candidateInfo_dict[series_uid]
+                )
                 all_confusion += one_confusion
                 all_confusion += one_confusion
-                print_confusion(series_uid, one_confusion, self.malignancy_model is not None)
+                print_confusion(
+                    series_uid, one_confusion, self.malignancy_model is not None
+                )
 
 
-        print_confusion("Total", all_confusion, self.malignancy_model is not None)
+        print_confusion(
+            "Total", all_confusion, self.malignancy_model is not None
+        )
 
 
 
 
     def classifyCandidates(self, ct, candidateInfo_list):
     def classifyCandidates(self, ct, candidateInfo_list):
@@ -350,14 +368,11 @@ class NoduleAnalysisApp:
                 else:
                 else:
                     probability_mal_g = torch.zeros_like(probability_nodule_g)
                     probability_mal_g = torch.zeros_like(probability_nodule_g)
 
 
-            zip_iter = zip(
-                center_list,
+            zip_iter = zip(center_list,
                 probability_nodule_g[:,1].tolist(),
                 probability_nodule_g[:,1].tolist(),
-                probability_mal_g[:,1].tolist(),
-            )
+                probability_mal_g[:,1].tolist())
             for center_irc, prob_nodule, prob_mal in zip_iter:
             for center_irc, prob_nodule, prob_mal in zip_iter:
-                center_xyz = irc2xyz(
-                    center_irc,
+                center_xyz = irc2xyz(center_irc,
                     direction_a=ct.direction_a,
                     direction_a=ct.direction_a,
                     origin_xyz=ct.origin_xyz,
                     origin_xyz=ct.origin_xyz,
                     vxSize_xyz=ct.vxSize_xyz,
                     vxSize_xyz=ct.vxSize_xyz,
@@ -369,9 +384,8 @@ class NoduleAnalysisApp:
     def segmentCt(self, ct, series_uid):
     def segmentCt(self, ct, series_uid):
         with torch.no_grad():
         with torch.no_grad():
             output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
             output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
-            seg_dl = self.initSegmentationDl(series_uid)
-            for batch_tup in seg_dl:
-                input_t, label_t, series_list, slice_ndx_list = batch_tup
+            seg_dl = self.initSegmentationDl(series_uid)  #  <3>
+            for input_t, _, _, slice_ndx_list in seg_dl:
 
 
                 input_g = input_t.to(self.device)
                 input_g = input_t.to(self.device)
                 prediction_g = self.seg_model(input_g)
                 prediction_g = self.seg_model(input_g)

+ 1 - 1
p2ch14/training.py

@@ -139,7 +139,7 @@ class ClassificationTrainingApp:
                 if n.split('.')[0] not in finetune_blocks:
                 if n.split('.')[0] not in finetune_blocks:
                     p.requires_grad_(False)
                     p.requires_grad_(False)
         if self.use_cuda:
         if self.use_cuda:
-            log.info("Using CUDA with {} devices.".format(torch.cuda.device_count()))
+            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
             if torch.cuda.device_count() > 1:
             if torch.cuda.device_count() > 1:
                 model = nn.DataParallel(model)
                 model = nn.DataParallel(model)
             model = model.to(self.device)
             model = model.to(self.device)

文件差异内容过多而无法显示
+ 19 - 21
p2ch14_malben_baseline.ipynb


+ 9 - 0
p3ch15/CMakeLists.txt

@@ -2,6 +2,15 @@ cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
 project(cyclegan-jit)
 project(cyclegan-jit)
 
 
 find_package(Torch REQUIRED)
 find_package(Torch REQUIRED)
+
+# Explicitly dealing with OpenMP seems to be needed with some C++
+# nightlies in March 2020. Probably a bug.
+find_package(OpenMP)
+if(OPENMP_FOUND)
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+endif()
+
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
 
 
 add_executable(cyclegan-jit cyclegan_jit.cpp)
 add_executable(cyclegan-jit cyclegan_jit.cpp)

+ 19 - 13
p3ch15/cyclegan_cpp_api.cpp

@@ -1,8 +1,8 @@
 // tag::header[]
 // tag::header[]
-#include <torch/torch.h>
+#include <torch/torch.h>  // <1>
 #define cimg_use_jpeg
 #define cimg_use_jpeg
 #include <CImg.h>
 #include <CImg.h>
-using torch::Tensor;
+using torch::Tensor;  // <2>
 // end::header[]
 // end::header[]
 
 
 // at the time of writing this code (shortly after PyTorch 1.3),
 // at the time of writing this code (shortly after PyTorch 1.3),
@@ -44,10 +44,13 @@ struct ResNetBlock : torch::nn::Module {
       : conv_block(  // <1>
       : conv_block(  // <1>
             torch::nn::ReflectionPad2d(1),
             torch::nn::ReflectionPad2d(1),
             torch::nn::Conv2d(torch::nn::Conv2dOptions(dim, dim, 3)),
             torch::nn::Conv2d(torch::nn::Conv2dOptions(dim, dim, 3)),
-            torch::nn::InstanceNorm2d(torch::nn::InstanceNorm2dOptions(dim)),
-            torch::nn::ReLU(/*inplace=*/true), torch::nn::ReflectionPad2d(1),
+            torch::nn::InstanceNorm2d(
+	       torch::nn::InstanceNorm2dOptions(dim)),
+            torch::nn::ReLU(/*inplace=*/true),
+	    torch::nn::ReflectionPad2d(1),
             torch::nn::Conv2d(torch::nn::Conv2dOptions(dim, dim, 3)),
             torch::nn::Conv2d(torch::nn::Conv2dOptions(dim, dim, 3)),
-            torch::nn::InstanceNorm2d(torch::nn::InstanceNorm2dOptions(dim))) {
+            torch::nn::InstanceNorm2d(
+	       torch::nn::InstanceNorm2dOptions(dim))) {
     register_module("conv_block", conv_block); // <2>
     register_module("conv_block", conv_block); // <2>
   }
   }
 
 
@@ -64,7 +67,7 @@ struct ResNetGeneratorImpl : torch::nn::Module {
                       int64_t ngf = 64, int64_t n_blocks = 9) {
                       int64_t ngf = 64, int64_t n_blocks = 9) {
     TORCH_CHECK(n_blocks >= 0);
     TORCH_CHECK(n_blocks >= 0);
     model->push_back(torch::nn::ReflectionPad2d(3)); // <1>
     model->push_back(torch::nn::ReflectionPad2d(3)); // <1>
-                                                     // end::generator1[]
+// end::generator1[]
     model->push_back(
     model->push_back(
         torch::nn::Conv2d(torch::nn::Conv2dOptions(input_nc, ngf, 7)));
         torch::nn::Conv2d(torch::nn::Conv2dOptions(input_nc, ngf, 7)));
     model->push_back(
     model->push_back(
@@ -79,7 +82,7 @@ struct ResNetGeneratorImpl : torch::nn::Module {
           torch::nn::Conv2dOptions(ngf * mult, ngf * mult * 2, 3)
           torch::nn::Conv2dOptions(ngf * mult, ngf * mult * 2, 3)
               .stride(2)
               .stride(2)
               .padding(1))); // <3>
               .padding(1))); // <3>
-                             // end::generator2[]
+      // end::generator2[]
       model->push_back(torch::nn::InstanceNorm2d(
       model->push_back(torch::nn::InstanceNorm2d(
           torch::nn::InstanceNorm2dOptions(ngf * mult * 2)));
           torch::nn::InstanceNorm2dOptions(ngf * mult * 2)));
       model->push_back(torch::nn::ReLU(/*inplace=*/true));
       model->push_back(torch::nn::ReLU(/*inplace=*/true));
@@ -114,7 +117,7 @@ TORCH_MODULE(ResNetGenerator); // <4>
 int main(int argc, char **argv) {
 int main(int argc, char **argv) {
   // tag::main1[]
   // tag::main1[]
   ResNetGenerator model; // <1>
   ResNetGenerator model; // <1>
-                         // end::main1[]
+  // end::main1[]
   if (argc != 3) {
   if (argc != 3) {
     std::cerr << "call as " << argv[0] << " model_weights.pt image.jpg"
     std::cerr << "call as " << argv[0] << " model_weights.pt image.jpg"
               << std::endl;
               << std::endl;
@@ -122,7 +125,7 @@ int main(int argc, char **argv) {
   }
   }
   // tag::main2[]
   // tag::main2[]
   torch::load(model, argv[1]); // <2>
   torch::load(model, argv[1]); // <2>
-                               // end::main2[]
+  // end::main2[]
   // you can print the model structure just like you would in PyTorch
   // you can print the model structure just like you would in PyTorch
   // std::cout << model << std::endl;
   // std::cout << model << std::endl;
   // tag::main3[]
   // tag::main3[]
@@ -132,12 +135,15 @@ int main(int argc, char **argv) {
       torch::tensor(torch::ArrayRef<float>(image.data(), image.size()));
       torch::tensor(torch::ArrayRef<float>(image.data(), image.size()));
   auto input = input_.reshape({1, 3, image.height(), image.width()});
   auto input = input_.reshape({1, 3, image.height(), image.width()});
   torch::NoGradGuard no_grad;          // <3>
   torch::NoGradGuard no_grad;          // <3>
+  
   model->eval();                       // <4>
   model->eval();                       // <4>
+  
   auto output = model->forward(input); // <5>
   auto output = model->forward(input); // <5>
-                                       // end::main3[]
-                                       // tag::main4[]
-  cimg_library::CImg<float> out_img(output.data_ptr<float>(), output.size(3),
-                                    output.size(2), 1, output.size(1));
+  // end::main3[]
+  // tag::main4[]
+  cimg_library::CImg<float> out_img(output.data_ptr<float>(),
+				    output.size(3), output.size(2),
+				    1, output.size(1));
   cimg_library::CImgDisplay disp(out_img, "See a C++ API zebra!"); // <6>
   cimg_library::CImgDisplay disp(out_img, "See a C++ API zebra!"); // <6>
   while (!disp.is_closed()) {
   while (!disp.is_closed()) {
     disp.wait();
     disp.wait();

+ 8 - 4
p3ch15/cyclegan_jit.cpp

@@ -4,7 +4,7 @@
 #include "CImg.h"
 #include "CImg.h"
 using namespace cimg_library;
 using namespace cimg_library;
 int main(int argc, char **argv) {
 int main(int argc, char **argv) {
-  // end::part1[]
+// end::part1[]
   if (argc != 4) {
   if (argc != 4) {
     std::cerr << "Call as " << argv[0] << " model.pt input.jpg output.jpg"
     std::cerr << "Call as " << argv[0] << " model.pt input.jpg output.jpg"
               << std::endl;
               << std::endl;
@@ -15,13 +15,17 @@ int main(int argc, char **argv) {
   image = image.resize(227, 227); // <3>
   image = image.resize(227, 227); // <3>
   // end::part2[]
   // end::part2[]
   // tag::part3[]
   // tag::part3[]
-  auto input_ = torch::tensor(torch::ArrayRef<float>(image.data(),
-                                                     image.size())); // <1>
-  auto input = input_.reshape({1, 3, image.height(), image.width()}).div_(255); // <2>
+  auto input_ = torch::tensor(
+    torch::ArrayRef<float>(image.data(), image.size())); // <1>
+  auto input = input_.reshape({1, 3, image.height(),
+			       image.width()}).div_(255); // <2>
+
   auto module = torch::jit::load(argv[1]); // <3>
   auto module = torch::jit::load(argv[1]); // <3>
+
   std::vector<torch::jit::IValue> inputs; // <4>
   std::vector<torch::jit::IValue> inputs; // <4>
   inputs.push_back(input);
   inputs.push_back(input);
   auto output_ = module.forward(inputs).toTensor(); // <5>
   auto output_ = module.forward(inputs).toTensor(); // <5>
+
   auto output = output_.contiguous().mul_(255); // <6>
   auto output = output_.contiguous().mul_(255); // <6>
 // end::part3[]
 // end::part3[]
 // tag::part4[]
 // tag::part4[]

+ 4 - 2
p3ch15/flask_server.py

@@ -10,7 +10,8 @@ from p2ch13.model_cls import LunaModel
 app = Flask(__name__)
 app = Flask(__name__)
 
 
 model = LunaModel()
 model = LunaModel()
-model.load_state_dict(torch.load(sys.argv[1], map_location='cpu')['model_state'])
+model.load_state_dict(torch.load(sys.argv[1],
+                                 map_location='cpu')['model_state'])
 model.eval()
 model.eval()
 
 
 def run_inference(in_tensor):
 def run_inference(in_tensor):
@@ -25,7 +26,8 @@ def run_inference(in_tensor):
 def predict():
 def predict():
     meta = json.load(request.files['meta'])
     meta = json.load(request.files['meta'])
     blob = request.files['blob'].read()
     blob = request.files['blob'].read()
-    in_tensor = torch.from_numpy(np.frombuffer(blob, dtype=np.float32))
+    in_tensor = torch.from_numpy(np.frombuffer(
+        blob, dtype=np.float32))
     in_tensor = in_tensor.view(*meta['shape'])
     in_tensor = in_tensor.view(*meta['shape'])
     out = run_inference(in_tensor)
     out = run_inference(in_tensor)
     return jsonify(out)
     return jsonify(out)

+ 3 - 1
p3ch15/request_batching_jit_server.py

@@ -84,7 +84,9 @@ class ModelRunner:
             batch = torch.stack([t["input"] for t in to_process], dim=0)
             batch = torch.stack([t["input"] for t in to_process], dim=0)
             # we could delete inputs here...
             # we could delete inputs here...
 
 
-            result = await app.loop.run_in_executor(None, functools.partial(self.run_model, batch))
+            result = await app.loop.run_in_executor(
+                None, functools.partial(self.run_model, batch)
+            )
             for t, r in zip(to_process, result):
             for t, r in zip(to_process, result):
                 t["output"] = r
                 t["output"] = r
                 t["done_event"].set()
                 t["done_event"].set()

+ 10 - 2
p3ch15/request_batching_server.py

@@ -33,9 +33,14 @@ class ModelRunner:
     def __init__(self, model_name):
     def __init__(self, model_name):
         self.model_name = model_name
         self.model_name = model_name
         self.queue = []
         self.queue = []
+
         self.queue_lock = None
         self.queue_lock = None
-        self.model = get_pretrained_model(self.model_name, map_location=device)
+
+        self.model = get_pretrained_model(self.model_name,
+                                          map_location=device)
+
         self.needs_processing = None
         self.needs_processing = None
+
         self.needs_processing_timer = None
         self.needs_processing_timer = None
 
 
     def schedule_processing_if_needed(self):
     def schedule_processing_if_needed(self):
@@ -56,6 +61,7 @@ class ModelRunner:
             self.queue.append(our_task)
             self.queue.append(our_task)
             logger.debug("enqueued task. new queue size {}".format(len(self.queue)))
             logger.debug("enqueued task. new queue size {}".format(len(self.queue)))
             self.schedule_processing_if_needed()
             self.schedule_processing_if_needed()
+
         await our_task["done_event"].wait()
         await our_task["done_event"].wait()
         return our_task["output"]
         return our_task["output"]
 
 
@@ -85,7 +91,9 @@ class ModelRunner:
             batch = torch.stack([t["input"] for t in to_process], dim=0)
             batch = torch.stack([t["input"] for t in to_process], dim=0)
             # we could delete inputs here...
             # we could delete inputs here...
 
 
-            result = await app.loop.run_in_executor(None, functools.partial(self.run_model, batch))
+            result = await app.loop.run_in_executor(
+                None, functools.partial(self.run_model, batch)
+            )
             for t, r in zip(to_process, result):
             for t, r in zip(to_process, result):
                 t["output"] = r
                 t["output"] = r
                 t["done_event"].set()
                 t["done_event"].set()

部分文件因为文件数量过多而无法显示