{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%matplotlib inline\n", "from matplotlib import pyplot as plt\n", "import numpy as np\n", "import torch\n", "\n", "torch.set_printoptions(edgeitems=2)\n", "torch.manual_seed(123)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class_names = ['airplane','automobile','bird','cat','deer',\n", " 'dog','frog','horse','ship','truck']" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from torchvision import datasets, transforms\n", "data_path = '../data-unversioned/p1ch7/'\n", "cifar10 = datasets.CIFAR10(\n", " data_path, train=True, download=False,\n", " transform=transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4915, 0.4823, 0.4468),\n", " (0.2470, 0.2435, 0.2616))\n", " ]))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "cifar10_val = datasets.CIFAR10(\n", " data_path, train=False, download=False,\n", " transform=transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4915, 0.4823, 0.4468),\n", " (0.2470, 0.2435, 0.2616))\n", " ]))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "label_map = {0: 0, 2: 1}\n", "class_names = ['airplane', 'bird']\n", "cifar2 = [(img, label_map[label])\n", " for img, label in cifar10 \n", " if label in [0, 2]]\n", "cifar2_val = [(img, label_map[label])\n", " for img, label in cifar10_val\n", " if label in [0, 2]]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "n_out = 2\n", "\n", "model = nn.Sequential(\n", " nn.Linear(\n", " 3072, # <1>\n", " 512, # <2>\n", " ),\n", " nn.Tanh(),\n", " nn.Linear(\n", " 512, # <2>\n", " 
n_out, # <3>\n",
    "            )\n",
    "        )"
   ] },
  { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [],
   "source": [
    "def softmax(x):\n",
    "    # Softmax is invariant to shifting its input, so subtract the max first:\n",
    "    # otherwise exp() of large values (e.g. 104.0 below) overflows to inf and\n",
    "    # the division yields nan. Empty input is passed through unchanged.\n",
    "    shifted = x - x.max() if x.numel() else x\n",
    "    exp_x = torch.exp(shifted)\n",
    "    return exp_x / exp_x.sum()"
   ] },
  { "cell_type": "code", "execution_count": 8, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "tensor([0.0900, 0.2447, 0.6652])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "x = torch.tensor([1.0, 2.0, 3.0])\n",
    "\n",
    "softmax(x)"
   ] },
  { "cell_type": "code", "execution_count": 9, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "tensor(1.)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "softmax(x).sum()"
   ] },
  { "cell_type": "code", "execution_count": 10, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "tensor([[0.0900, 0.2447, 0.6652],\n", "        [0.0900, 0.2447, 0.6652]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "softmax = nn.Softmax(dim=1)\n",
    "\n",
    "x = torch.tensor([[1.0, 2.0, 3.0],\n",
    "                  [1.0, 2.0, 3.0]])\n",
    "\n",
    "softmax(x)"
   ] },
  { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [],
   "source": [
    "model = nn.Sequential(\n",
    "            nn.Linear(3072, 512),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(512, 2),\n",
    "            nn.Softmax(dim=1))"
   ] },
  { "cell_type": "code", "execution_count": 12, "metadata": {},
   "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" ] },
    { "data": { "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGXFJREFUeJztnXt4VeWVxt8lAUEDRQhICtiAYos3Lg14G6yCqPg4g1i1OtWh1hHt6Iyd6Uyljq20tR3tU7XaUgWrI/YRFeu99YYpLVQZJCImYqxcREECJGKEiBhD1vxxDhriXisn+5yzT9Lv/T1PniTrzbf3OvucN/ucvfa3PlFVEELCY59CJ0AIKQw0PyGBQvMTEig0PyGBQvMTEig0PyGBQvMTEig0PyGBQvMTEihF2QwWkdMA3AKgG4DfqOr13t/3LhHtXxatvfWGM3Df6PA+Pe0h3aWbqfXaz37YBxSXmFofDIyMFzn/Qxux3dQ27lhjasW97TsvP2cqQHcjvssZYxxeAP7Zwbs3dLcR398Zkw8ajHiTM+ZD8ygC3qPevqPZ1Jo+dDa509Es3jfiuwFtUclkExL39l4R6QbgDQCTAWwEsBzA+ar6mjWmrFz0+5XR2j+f4uxsWHS4z0jbxAOKbIuMPqq/qU2bcLGpTZbLI+MDnZf0C3jO1L5bcYapHTvpI1OzRwEDjPhqZ4xxeAEAxY7m/UNpNOLjnTFxaXG0J4z4emdMDQabWjNsgz9XscXU3qpxdviyo1k8acTrAf04M/Nn87Z/PIA1qrpOVZsA3A9gahbbI4QkSDbmHwxgQ6vfN6ZjhJAuQDbmj3pr8ZnPECIyQ0QqRaRyR10WeyOE5JRszL8RwNBWvw8BsKntH6nqXFUtV9Xy3tYHUkJI4mRj/uUARojIMBHpAeA8AI/nJi1CSL6JXepT1WYRuQLAM0iV+u5S1VXemG5wrh5/1Rl4WXR4+0j7yuv2I981tTc329qaF+fY2sTow3X+2JPMMcc5hblfT+praqthXzk2CiYA7CvwQ5wx9lEE6h3Nzj7uVf3hjnaUqVRiuan9z2PvRMaLh0aGU5TYj7riVrsK02OMs007Rbse6WE90R0o3mVV51fVJ2EXHQghnRje4UdIoND8hAQKzU9IoND8hAQKzU9IoGR1tb+jbAHwS0ObdKk9rsKoD44a49VW7HLeK99729YeX2drU74TGa+eZZehJo+vMjWvwuNMWMRGR7MqSlOcMYMc7VhH64MDHdU6/l5h0S6jvYAzTW3hY3Y5ddmZ86KFs+wsjv6VnQeOtKWmF20Nbzqa5cJFzpgcwDM/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGS6NX+D94Bnr8mWjv4Onvc5V+Pjs9+yOl/5LVNGudo3rzEp6LDSy+3r+j/vbM572r/TY7mMdmIe9fYvTZefdz+LNE9DQHgMuPRTcah5phxTjfBOqehWPXQC0wNMK72OwfkS6W21jDB1v7qXdF3tlmo2TE88xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGSaKkPmwH8JFpa6wyb/Q+G4M1+cRrMHeysDrTWKxHOjw5vchrdfWOxsz0n/4Exl7axFhuzSoAAcKi7AJi97NkfPtus+RN2GU/AMNiTsZbC7oV4ntfkcawtmY+8bKE5YqG55hSwabazK2/G1QZHs5Y3yjM88xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGSValPRNYD2AFgN4BmVS2PvbF7Hc0qvzn91PBNW2p2lmo6+le2tsxaImm1k4eHM4Ow8WZb+5eDbM1aCHmJk0Y13je1Mkdb6WxznNHf7z38yRxzH86xNxi1JnRGnB0dbt7fHLFp9qP25ryZewc4WidcoToXdf6TVNVb0o0Q0gnh235CAiVb8yuAZ0XkJRGZkYuECCHJkO3b/uNVdZOIDASwUEReV9W9bmhN/1PgPwZCOhlZnflVdVP6+1YAjyBiWXZVnau
q5VldDCSE5JzY5heR/UWk956fAZwC4NVcJUYIyS+iqvEGigxH6mwPpD4+zFdVY87eJ2Pi7ezHRvweZ4y3zpQzm+6Lc2ztEiN+hLOreqcB5owX3zG1ndX2NkddbGvWBDFv1uTxjvYtR7NmEAJAKaLrkdXYbY65oMqpYY5yOmfiZ44WA++BTXQ0u8cosNTRrBl/MWf7qWpGhdHYn/lVdR2AUXHHE0IKC0t9hAQKzU9IoND8hAQKzU9IoND8hARK7FJfrJ3FLfVNMeK9nTHeTLv3HM1qFgoAJxhxZ+2/rzrVK6fHKO5c4YheM0ij8efhzlpxzlw6t4zpTI5EmRGvQX9zzIkVh9kbPNmaUgkAy23JKr+NdDZ3rqN5DV5rHc1Y5zEfZFrq45mfkECh+QkJFJqfkECh+QkJFJqfkEBJdrkuDy+TNUb8n2Lua4GjPexo1gpPZfaQhy53tucsybWPvaoVBjnLU40w4hc6aRziaB5eWzpLa8a79qCFh8ZL5G7nar/B1Om2NsgZN+dSR7RXAOuU8MxPSKDQ/IQECs1PSKDQ/IQECs1PSKDQ/IQESucp9TU72g4j7k2y8HB6+LlHZLIR9yYYbXa0eU4aV9rahO7ONg28JZWsh9UecQ6/Nz0HN9iTfuD0Qrx9+mmmdgiejox7x2O9o7kDvddwJ4RnfkICheYnJFBofkICheYnJFBofkICheYnJFDaLfWJyF0AzgCwVVWPSMf6AXgAqfls6wGcq6peZ7zssMpl850xzqw4c+ob4Pf+O8CIe7MLvRmEznJMTU6fvvXDbW2IEe+JA80xT2GLqXl9Bp9wNGsipo+3N3u9qzKnh5/1lHn5NcOeXTjqyjdM7ZUjnY3+0NGs16pXkh5gxP/sjGlDJmf+uwG0LaTOBFChqiMAVKR/J4R0Ido1v6ouBrCtTXgqPr1FZR6AM3OcFyEkz8T9zH+gqtYCQPr7wNylRAhJgrzf3isiMwDMyPd+CCEdI+6Zf4uIlAJA+vtW6w9Vda6qlqtqecx9EULyQFzzPw5gTxe06QAey006hJCkaHe5LhG5D8CJAEoAbAFwLYBHkSpiHQTgbQDnqGrbi4JR20pubTAPr0OjNwvPKuV4H2p6OZoznc5b5usc7O9sNHp9qhKn1FePKlP7P2dPv/jQEW804vc4Y1b/xtZG2p1Qv/baR6Y2wYh/CUeZY8bhFlNrxhxTK4L9pC3Hwaa22Tj+jbDLiq/r2sj4/HEbsaXyo4yW62r3M7+qnm9IkzLZASGkc8I7/AgJFJqfkECh+QkJFJqfkECh+QkJlM7TwLOz4M2kqjbi33PG/NyWpjvlPGs2GgCsh93ossQoKRU5T/VKZ1+/8O7g+KOjWY0uvVmTWGdLk+0npgF2qc+q3BY75c0ifNfUSvG2qR1q1jeBSfi6qVmtULfhHXNEPzk5Mr4Emd9LxzM/IYFC8xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKF271Odl762b1uBo7mJyBk4jTr+0Zdf6ijDN2Z3dKXKIMVutHu+aY5asMCVg6UJby/m6dT81laMP2NfU/t3ZopWi18BzidMQ1Ht5/AgXmNpwZ5vA5yOjy/GBOeJUnORsLzN45ickUGh+QgKF5ickUGh+QgKF5ickULr21f5YV5QR74p+XKJb6qWk9+yr2x/uOs/UDinpZm/UeEaLnIrEJWPbLsj0KReNtcetsZs2o/qZ6Ek6f1hwg71BPGoqE5rtyTunftJL9rPc+MnaMnsTZyUsAKh1tDcdbYjTF9B6arz+iZUf3h0Zr23xmlDuDc/8hAQKzU9IoND8hAQKzU9IoND8hAQKzU9IoLRb6hORuwCcAWCrqh6Rjs0CcAmAuvSfXa2qT+YryU6PU87r1/gDU3twtr2E04C+djmvYYS9v0aj0rNmtV0qKxthT5rp2dfe14SJ9srsg46L1p4660pzTMvDdqlvqVNHe80o5wHAaCM+DIPNMRuc3nm9Hcs0w37O/tfpMzjEiE8xRwA
9e0VP4Jq/z/vOqL3J5Mx/N4CoQvDNqjo6/RWu8QnporRrflVdDKDdRTgJIV2LbD7zXyEiVSJyl4h4naYJIZ2QuOa/DcDBSH2kqoW9IDNEZIaIVIpIZcx9EULyQCzzq+oWVd2tqi0A7gAw3vnbuaparqqZryZACMk7scwvIqWtfp0G4NXcpEMISYpMSn33ATgRQImIbARwLYATRWQ0AAWwHsCleczR5PNldolq2AnmmxEU7bIf9p8XLOp4IsP+w5S2vTnBHlf3liltHbG/qdVutstU26rfiBaqVpljVjXaveLQaJeOHho3xtR6jIkuY7Y87PQEdHjeWioNwK+dcb2NeJ1TzhvpbG+yM5W0r6N5bSOtjozj4c2AvCwy2gtfccbsTbvmV9XzI8J3ZrwHQkinhHf4ERIoND8hgULzExIoND8hgULzExIoiTbwHNz/8/jXqdElip4n2GWjnmMOi4yfNGy4OabYqvHAnYSHswZdYWoVt94fLVjlNQCofttJxC7nod5eQ2tb3YHOuOjGmXBKW0B/R3PW8lpiz1hsWmJt83POvhycUp/Xj/VpI772OmeQ16XTaWh66cW29hdnk1b+xzkNTe1EnLJtG3jmJyRQaH5CAoXmJyRQaH5CAoXmJyRQaH5CAiXRUt+gslJcdef3k9xlh6mpdxa1w7tG/PfxdubtqsYrv51tS32PjY43OGVFOOVIZz0+H+tYWfH4eGvamS9w75XvTRN0pvzNcZqdmlP3AKwaFh1/ovtSc8xPMDkyvt1JoS088xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKIle7e8K1Fd7kymSxLsqPseWGqw+cnZ/OcCYsNSZcF6pqx5zxp0QHf7yTHvISxuc7RnLoQEAvHGnd3zcSxvtIbcZj6sjtRme+QkJFJqfkECh+QkJFJqfkECh+QkJFJqfkEDJZLmuoQDuATAIQAuAuap6i4j0A/AAgDKkluw6V1Xfy1+qHaMJm0yth1NGK6q2l6dqyiqjpPgbXUxphqMNdTSjwlntvFK/6PT367vD1mqcMmDPXra21eg3ebjd1hK7PoyOa4s9pi2ZnPmbAXxHVUcCOAbA5SJyGICZACpUdQSAivTvhJAuQrvmV9VaVV2R/nkHgBoAgwFMBTAv/WfzAJyZryQJIbmnQ5/5RaQMwBgAywAcqKq1QOofBICBuU6OEJI/Mja/iBQDeAjAt1U1454BIjJDRCpFpLKuri5OjoSQPJCR+UWkO1LGv1dVH06Ht4hIaVovhXFbsarOVdVyVS0fMGBALnImhOSAds0vIoLUJeQaVb2plfQ4gOnpn6cD8KZXEEI6GZnM6jsewIUAqkVkZTp2NYDrASwQkYuRagJ3Tn5SBLYZ8UZYS1MBDfqcqQ3CFlPbmWlSJFGOnm1ry56xtT7GqlbeC7/WaWl40UFHmdq0g6pMzZtTeY0x7JhJ9hirJeBrHbiK1675VfUvAMSQnfQIIZ0Z3uFHSKDQ/IQECs1PSKDQ/IQECs1PSKB0iQae/Yx4MYabYzb/8R1Te6p+iantV2znsdNbXotkz5SY4162pQNOjY5700+nHWRr52BfU+vpbHORox0/MTruTVa874Xo+LYOvEZ55ickUGh+QgKF5ickUGh+QgKF5ickUGh+QgKlS5T64lAybLCplU20OyOOqbbLgM//JHpu1pevsvN4yZb82tBqR5vvbTRBjnW0pTG2d40tTcbnTG30TPtlvMZo1rpc7X3tsqaxAbgJy03Nq1Q6y+5hgrG/OifHDW9Gx5s60GWWZ35CAoXmJyRQaH5CAoXmJyRQaH5CAiXRq/0tsHvkNRrLDwFAX2OpoyJ8YI4ZPtye9NO4Y7GpWVf0PWrmOOLpjuZ1Mh/R4TSSpyHGmCGO5iyFdd1Eexk1jHS2aVQQ9nEmcD1gXEkHADgTZ54+ztZOczY5wYg3OFWHhrOi40/d6OyoDTzzExIoND8hgULzExIoND8hgULzExIoND8hgdJuqU9EhgK4B8A
gpKp1c1X1FhGZBeASfFqwulpVn/S2tQ+A/Qyt3ikb9TBKfVvxe3PMgw+cZ2pX2JL737DFiO/0Sl5xJ+EsjDkuSTpeFQWM5xIA8A1H2+xoXoO88dHhFq+Jnzcp6VxbWvtzW5ttzxfDGGOVy4tgT06r7hXdo7JbLpfrQuop/o6qrhCR3gBeEpE9L82bVdV5yISQzkoma/XVAqhN/7xDRGoA518SIaRL0KHP/CJSBmAMgGXp0BUiUiUid4nIATnOjRCSRzI2v4gUA3gIwLdVdTuA2wAcDGA0Uu8MIm8sFJEZIlIpIpV1dd79rISQJMnI/CLSHSnj36uqDwOAqm5R1d2q2gLgDhiXVlR1rqqWq2r5gAEDcpU3ISRL2jW/iAiAOwHUqOpNreKlrf5sGoBXc58eISRfZHK1/3gAFwKoFpGV6djVAM4XkdEAFMB6AJe2t6FG7MILqInUajesNcdVvxYd/+0iu2b3wLPtZRONVc7rVPybo92a431da0s9jrS1prMNwetNGBen/GbOFKx3xixwNK+YHXM5t18aM1qPNMp5AHBHVXT8Y2d2bFsyudr/FwBRkwvdmj4hpHPDO/wICRSan5BAofkJCRSan5BAofkJCZREG3hu/7geC2vvjtSqX7jXHNe4OrrksehlZ2deE8YuzqSv2VpFrkt982ypyZvNOM6I26tdxWeYow014t1j7itmOc9ryPqK8Tp+xGkIWmI8rroemafEMz8hgULzExIoND8hgULzExIoND8hgULzExIoiZb6PvpgG1a/GF3SK+ptz2AqMZowTnDWaKv4L1vrY0vY7mgWp06xtWeeirFBAJOsUhmAMWNsrcKa8Re3BLje0fo6mlXa8pp+eqVbjziNRE9ytH90tLgNWb3ZjMZaj9c7axeOOjU6/l63jDPimZ+QUKH5CQkUmp+QQKH5CQkUmp+QQKH5CQmUZEt973+MNU9Gl/SKrdlXADYaWQ5yymFTH7W1eqd5Y4OTx67F0fGlccs/DhXO7LeKmc5Aozv6frfbQ3bOcrbnNMc8/Ou2dohRnu3p7OoRY806AGjyOkYOcTTrud7ljHFKyHnBKnE6B6vamMnY4j2uNvDMT0ig0PyEBArNT0ig0PyEBArNT0igtHu1X0R6AlgMYN/03/9OVa8VkWEA7gfQD8AKABeqapO3rf49gYuMCR+vO1fZrfkjzc6i4IPG2trmFbZW41y5b4lch7gAeH3k7okO73T67X35x7a2wVlYedUNjnZKdHw/Z1LST6faWo2jPae29pa19FatPQaljuZUmGL393uv40OKjErAxx04nWfypx8BmKiqo5Bajvs0ETkGwA0AblbVEUilf3HmuyWEFJp2za8p9vxP657+UgATAfwuHZ8H4My8ZEgIyQsZvUkQkW7pFXq3AlgIYC2ABlXdM5N6I4DB+UmREJIPMjK/qu5W1dFI3Us1HtH3QEV+8hKRGSJSKSKVjXE/ExFCck6HrvaragOAPwE4BkBfEdlzwXAIgE3GmLmqWq6q5cXF2aRKCMkl7ZpfRAaISN/0z70AnAygBsAiAGen/2w6AOfObEJIZ0NUnToJABE5CqkLet2Q+mexQFV/JCLD8Wmp72UAF6jqR962RpeKPmvUBDZeta857r4bozc735ns0eBMzujrlHkaFtraTlvKPSWO5pWbYvYMNJngaM4kkj5GqW+7s4zaF75pa2dMsjVjLhMA4NZ10fFttziDnDIxnNKnu0Scl+TDRtzrTWhN1JoB6OsqzshPaLfOr6pVAD5TnVXVdUh9/ieEdEF4hx8hgULzExIoND8hgULzExIoND8hgdJuqS+nOxOpA/BW+tcS2B3WkoR57A3z2JuulscXVNUrLH5Coubfa8cilapaXpCdMw/mwTz4tp+QUKH5CQmUQpp/bgH33RrmsTfMY2/+ZvMo2Gd+Qkhh4dt+QgKlIOYXkdNE5K8iskZEvMWn8p3HehGpFpGVIlKZ4H7vEpGtIvJqq1g/EVkoIqv
T3532pHnNY5aIvJM+JitF5PQE8hgqIotEpEZEVonIlel4osfEySPRYyIiPUXkRRF5JZ3HD9PxYSKyLH08HhCRHlntSFUT/UJqavBaAMMB9ADwCoDDks4jnct6ACUF2O8JSE0cfbVV7GcAZqZ/ngnghgLlMQvAfyZ8PEoBjE3/3BvAGwAOS/qYOHkkekwACIDi9M/dASxDqoHOAgDnpeO3A/hWNvspxJl/PIA1qrpOU62+7wfgNGb+20NVFwPY1iY8Fam+CUBCDVGNPBJHVWtVdUX65x1INYsZjISPiZNHomiKvDfNLYT5BwPY0Or3Qjb/VADPishLIjKjQDns4UBVrQVSL0IAAwuYyxUiUpX+WJD3jx+tEZEypPpHLEMBj0mbPICEj0kSTXMLYf6oLiOFKjkcr6pjAUwBcLmInFCgPDoTtwE4GKk1GmoBJLZUiYgUA3gIwLdVdXtS+80gj8SPiWbRNDdTCmH+jQBar89jNv/MN6q6Kf19K4BHUNjORFtEpBQA0t+3FiIJVd2SfuG1ALgDCR0TEemOlOHuVdU9ja0SPyZReRTqmKT33eGmuZlSCPMvBzAifeWyB4DzADyedBIisr+I9N7zM4BTALzqj8orjyPVCBUoYEPUPWZLMw0JHBMREQB3AqhR1ZtaSYkeEyuPpI9JYk1zk7qC2eZq5ulIXUldC+C/C5TDcKQqDa8AWJVkHgDuQ+rt48dIvRO6GEB/ABUAVqe/9ytQHr8FUA2gCinzlSaQx98h9Ra2CsDK9NfpSR8TJ49EjwmAo5BqiluF1D+aH7R6zb4IYA2ABwHsm81+eIcfIYHCO/wICRSan5BAofkJCRSan5BAofkJCRSan5BAofkJCRSan5BA+X+oAC6reFaYfAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "img, _ = cifar2[0]\n", "\n", "plt.imshow(img.permute(1, 2, 0))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "img_batch = img.view(-1).unsqueeze(0)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.4784, 0.5216]], grad_fn=)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out = model(img_batch)\n", "out" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "_, index = torch.max(out, dim=1)\n", "\n", "index" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1., 0.],\n", " [1., 0.],\n", " [0., 1.],\n", " [0., 1.]])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out = torch.tensor([\n", " [0.6, 0.4],\n", " [0.9, 0.1],\n", " [0.3, 0.7],\n", " [0.2, 0.8],\n", "])\n", "class_index = torch.tensor([0, 0, 1, 1]).unsqueeze(1)\n", "\n", "truth = torch.zeros((4,2))\n", "truth.scatter_(dim=1, index=class_index, value=1.0)\n", "truth" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.1500)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def mse(out):\n", " return ((out - truth) ** 2).sum(dim=1).mean()\n", "mse(out)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.6000],\n", " [0.9000],\n", " [0.7000],\n", " [0.8000]])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out.gather(dim=1, 
index=class_index)"
   ] },
  { "cell_type": "code", "execution_count": 19, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "tensor([0.3024])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "def likelihood(out):\n",
    "    prod = 1.0\n",
    "    for x in out.gather(dim=1, index=class_index):\n",
    "        prod *= x\n",
    "    return prod\n",
    "\n",
    "likelihood(out)"
   ] },
  { "cell_type": "code", "execution_count": 20, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "tensor([1.1960])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "def neg_log_likelihood(out):\n",
    "    # Same value as -likelihood(out).log(), but summing the per-sample\n",
    "    # log-probabilities avoids the product underflowing to 0 (and the\n",
    "    # loss to inf) when many probabilities are multiplied together.\n",
    "    return -out.gather(dim=1, index=class_index).log().sum(dim=0)\n",
    "\n",
    "neg_log_likelihood(out)"
   ] },
  { "cell_type": "code", "execution_count": 21, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "tensor([0.0750, 0.1500, 0.2500, 0.4750])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "out0 = out.clone().detach()\n",
    "out0[0] = torch.tensor([0.9, 0.1])  # more right\n",
    "\n",
    "out2 = out.clone().detach()\n",
    "out2[0] = torch.tensor([0.4, 0.6])  # slightly wrong\n",
    "\n",
    "out3 = out.clone().detach()\n",
    "out3[0] = torch.tensor([0.1, 0.9])  # very wrong\n",
    "\n",
    "mse_comparison = torch.tensor([mse(o) for o in [out0, out, out2, out3]])\n",
    "mse_comparison"
   ] },
  { "cell_type": "code", "execution_count": 22, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "tensor([-50.0000,   0.0000,  66.6667, 216.6667])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "((mse_comparison / mse_comparison[1]) - 1) * 100"
   ] },
  { "cell_type": "code", "execution_count": 23, "metadata": {},
   "outputs": [ { "data": { "text/plain": [ "tensor([0.7905, 1.1960, 1.6015, 2.9878])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "nll_comparison = torch.tensor([neg_log_likelihood(o)\n",
    "                              for o in [out0, out, out2, out3]])\n",
    "nll_comparison"
   ]
}, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-33.9016, 0.0000, 33.9016, 149.8121])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "((nll_comparison / nll_comparison[1]) - 1) * 100" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 1.]])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "softmax = nn.Softmax(dim=1)\n", "\n", "log_softmax = nn.LogSoftmax(dim=1)\n", "\n", "x = torch.tensor([[0.0, 104.0]])\n", "\n", "softmax(x)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 1.]])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "softmax = nn.Softmax(dim=1)\n", "\n", "log_softmax = nn.LogSoftmax(dim=1)\n", "\n", "x = torch.tensor([[0.0, 104.0]])\n", "\n", "softmax(x)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-inf, 0.]])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.log(softmax(x))" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-104., 0.]])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_softmax(x)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 1.]])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.exp(log_softmax(x))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " nn.Linear(3072, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 2),\n", " nn.LogSoftmax(dim=1))" ] }, { 
"cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "loss = nn.NLLLoss()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.5077, grad_fn=)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img, label = cifar2[0]\n", "\n", "out = model(img.view(-1).unsqueeze(0))\n", "\n", "loss(out, torch.tensor([label]))" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 5.347057\n", "Epoch: 1, Loss: 7.705317\n", "Epoch: 2, Loss: 6.510838\n", "Epoch: 3, Loss: 9.557189\n", "Epoch: 4, Loss: 4.151933\n", "Epoch: 5, Loss: 5.636873\n", "Epoch: 6, Loss: 6.531207\n", "Epoch: 7, Loss: 20.450516\n", "Epoch: 8, Loss: 5.072948\n", "Epoch: 9, Loss: 4.941860\n", "Epoch: 10, Loss: 6.445535\n", "Epoch: 11, Loss: 4.580799\n", "Epoch: 12, Loss: 6.660308\n", "Epoch: 13, Loss: 9.436373\n", "Epoch: 14, Loss: 16.786476\n", "Epoch: 15, Loss: 8.349138\n", "Epoch: 16, Loss: 8.176860\n", "Epoch: 17, Loss: 5.862664\n", "Epoch: 18, Loss: 8.218906\n", "Epoch: 19, Loss: 13.296558\n", "Epoch: 20, Loss: 7.313433\n", "Epoch: 21, Loss: 4.585245\n", "Epoch: 22, Loss: 11.706884\n", "Epoch: 23, Loss: 18.208710\n", "Epoch: 24, Loss: 0.343157\n", "Epoch: 25, Loss: 9.255491\n", "Epoch: 26, Loss: 10.466807\n", "Epoch: 27, Loss: 12.226366\n", "Epoch: 28, Loss: 12.728527\n", "Epoch: 29, Loss: 9.777843\n", "Epoch: 30, Loss: 6.128856\n", "Epoch: 31, Loss: 13.284330\n", "Epoch: 32, Loss: 10.321814\n", "Epoch: 33, Loss: 2.928349\n", "Epoch: 34, Loss: 8.623670\n", "Epoch: 35, Loss: 12.719531\n", "Epoch: 36, Loss: 4.030444\n", "Epoch: 37, Loss: 4.621825\n", "Epoch: 38, Loss: 13.210777\n", "Epoch: 39, Loss: 14.217413\n", "Epoch: 40, Loss: 3.880259\n", "Epoch: 41, Loss: 13.189833\n", "Epoch: 42, Loss: 17.787762\n", "Epoch: 43, Loss: 3.953930\n", "Epoch: 44, Loss: 0.640078\n", "Epoch: 
45, Loss: 9.262226\n", "Epoch: 46, Loss: 7.383645\n", "Epoch: 47, Loss: 5.352252\n", "Epoch: 48, Loss: 11.515299\n", "Epoch: 49, Loss: 12.266010\n", "Epoch: 50, Loss: 12.210896\n", "Epoch: 51, Loss: 3.987965\n", "Epoch: 52, Loss: 12.570765\n", "Epoch: 53, Loss: 13.025002\n", "Epoch: 54, Loss: 13.747946\n", "Epoch: 55, Loss: 6.783926\n", "Epoch: 56, Loss: 11.822943\n", "Epoch: 57, Loss: 8.200066\n", "Epoch: 58, Loss: 9.206728\n", "Epoch: 59, Loss: 7.715425\n", "Epoch: 60, Loss: 5.571069\n", "Epoch: 61, Loss: 13.017315\n", "Epoch: 62, Loss: 10.307802\n", "Epoch: 63, Loss: 2.660404\n", "Epoch: 64, Loss: 11.096642\n", "Epoch: 65, Loss: 5.284830\n", "Epoch: 66, Loss: 8.374750\n", "Epoch: 67, Loss: 1.418676\n", "Epoch: 68, Loss: 9.891462\n", "Epoch: 69, Loss: 9.079073\n", "Epoch: 70, Loss: 6.453581\n", "Epoch: 71, Loss: 8.293860\n", "Epoch: 72, Loss: 4.585221\n", "Epoch: 73, Loss: 14.174129\n", "Epoch: 74, Loss: 6.072280\n", "Epoch: 75, Loss: 5.925417\n", "Epoch: 76, Loss: 0.260600\n", "Epoch: 77, Loss: 3.055498\n", "Epoch: 78, Loss: 0.347163\n", "Epoch: 79, Loss: 3.497080\n", "Epoch: 80, Loss: 6.615281\n", "Epoch: 81, Loss: 8.944511\n", "Epoch: 82, Loss: 10.230938\n", "Epoch: 83, Loss: 6.776264\n", "Epoch: 84, Loss: 10.169885\n", "Epoch: 85, Loss: 7.014330\n", "Epoch: 86, Loss: 3.467798\n", "Epoch: 87, Loss: 3.772486\n", "Epoch: 88, Loss: 13.495383\n", "Epoch: 89, Loss: 11.781836\n", "Epoch: 90, Loss: 6.853724\n", "Epoch: 91, Loss: 3.313806\n", "Epoch: 92, Loss: 7.867707\n", "Epoch: 93, Loss: 16.117371\n", "Epoch: 94, Loss: 15.077475\n", "Epoch: 95, Loss: 17.807060\n", "Epoch: 96, Loss: 16.376089\n", "Epoch: 97, Loss: 9.348265\n", "Epoch: 98, Loss: 18.044790\n", "Epoch: 99, Loss: 15.565783\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "model = nn.Sequential(\n", " nn.Linear(3072, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 2),\n", " nn.LogSoftmax(dim=1))\n", "\n", "learning_rate = 1e-2\n", "\n", "optimizer = 
optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.NLLLoss()\n", "\n", "n_epochs = 100\n", "\n", "for epoch in range(n_epochs):\n", " for img, label in cifar2:\n", " out = model(img.view(-1).unsqueeze(0))\n", " loss = loss_fn(out, torch.tensor([label]))\n", " \n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n", " shuffle=True)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 0.604063\n", "Epoch: 1, Loss: 0.597974\n", "Epoch: 2, Loss: 0.271415\n", "Epoch: 3, Loss: 0.451056\n", "Epoch: 4, Loss: 0.629758\n", "Epoch: 5, Loss: 0.458762\n", "Epoch: 6, Loss: 0.277813\n", "Epoch: 7, Loss: 0.406921\n", "Epoch: 8, Loss: 0.951961\n", "Epoch: 9, Loss: 0.433738\n", "Epoch: 10, Loss: 0.351960\n", "Epoch: 11, Loss: 0.355687\n", "Epoch: 12, Loss: 0.518611\n", "Epoch: 13, Loss: 0.262623\n", "Epoch: 14, Loss: 0.221969\n", "Epoch: 15, Loss: 0.774132\n", "Epoch: 16, Loss: 0.324406\n", "Epoch: 17, Loss: 0.447701\n", "Epoch: 18, Loss: 0.299780\n", "Epoch: 19, Loss: 0.267090\n", "Epoch: 20, Loss: 0.279828\n", "Epoch: 21, Loss: 0.197123\n", "Epoch: 22, Loss: 0.196783\n", "Epoch: 23, Loss: 0.328715\n", "Epoch: 24, Loss: 0.334952\n", "Epoch: 25, Loss: 0.500689\n", "Epoch: 26, Loss: 0.186956\n", "Epoch: 27, Loss: 0.138649\n", "Epoch: 28, Loss: 0.239988\n", "Epoch: 29, Loss: 0.495020\n", "Epoch: 30, Loss: 0.251347\n", "Epoch: 31, Loss: 0.088298\n", "Epoch: 32, Loss: 0.175127\n", "Epoch: 33, Loss: 0.208338\n", "Epoch: 34, Loss: 0.145656\n", "Epoch: 35, Loss: 0.129570\n", "Epoch: 36, Loss: 0.200110\n", "Epoch: 37, Loss: 0.133076\n", "Epoch: 38, Loss: 0.230561\n", "Epoch: 39, Loss: 0.241688\n", "Epoch: 40, Loss: 0.106870\n", 
"Epoch: 41, Loss: 0.281168\n", "Epoch: 42, Loss: 0.175034\n", "Epoch: 43, Loss: 0.073779\n", "Epoch: 44, Loss: 0.171294\n", "Epoch: 45, Loss: 0.112456\n", "Epoch: 46, Loss: 0.132553\n", "Epoch: 47, Loss: 0.048826\n", "Epoch: 48, Loss: 0.076014\n", "Epoch: 49, Loss: 0.122317\n", "Epoch: 50, Loss: 0.103442\n", "Epoch: 51, Loss: 0.201585\n", "Epoch: 52, Loss: 0.145637\n", "Epoch: 53, Loss: 0.055844\n", "Epoch: 54, Loss: 0.046278\n", "Epoch: 55, Loss: 0.081562\n", "Epoch: 56, Loss: 0.058857\n", "Epoch: 57, Loss: 0.197200\n", "Epoch: 58, Loss: 0.044184\n", "Epoch: 59, Loss: 0.043374\n", "Epoch: 60, Loss: 0.032936\n", "Epoch: 61, Loss: 0.072488\n", "Epoch: 62, Loss: 0.060811\n", "Epoch: 63, Loss: 0.029262\n", "Epoch: 64, Loss: 0.036435\n", "Epoch: 65, Loss: 0.058120\n", "Epoch: 66, Loss: 0.063329\n", "Epoch: 67, Loss: 0.020670\n", "Epoch: 68, Loss: 0.077189\n", "Epoch: 69, Loss: 0.060933\n", "Epoch: 70, Loss: 0.070848\n", "Epoch: 71, Loss: 0.036434\n", "Epoch: 72, Loss: 0.084855\n", "Epoch: 73, Loss: 0.044776\n", "Epoch: 74, Loss: 0.037828\n", "Epoch: 75, Loss: 0.024554\n", "Epoch: 76, Loss: 0.018965\n", "Epoch: 77, Loss: 0.033381\n", "Epoch: 78, Loss: 0.016183\n", "Epoch: 79, Loss: 0.020083\n", "Epoch: 80, Loss: 0.041192\n", "Epoch: 81, Loss: 0.015122\n", "Epoch: 82, Loss: 0.014245\n", "Epoch: 83, Loss: 0.018538\n", "Epoch: 84, Loss: 0.044791\n", "Epoch: 85, Loss: 0.034532\n", "Epoch: 86, Loss: 0.010175\n", "Epoch: 87, Loss: 0.021837\n", "Epoch: 88, Loss: 0.005545\n", "Epoch: 89, Loss: 0.012682\n", "Epoch: 90, Loss: 0.026414\n", "Epoch: 91, Loss: 0.021372\n", "Epoch: 92, Loss: 0.025901\n", "Epoch: 93, Loss: 0.025262\n", "Epoch: 94, Loss: 0.047044\n", "Epoch: 95, Loss: 0.016064\n", "Epoch: 96, Loss: 0.059213\n", "Epoch: 97, Loss: 0.017386\n", "Epoch: 98, Loss: 0.016215\n", "Epoch: 99, Loss: 0.016987\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "train_loader = torch.utils.data.DataLoader(cifar2, 
batch_size=64,\n", " shuffle=True)\n", "\n", "model = nn.Sequential(\n", " nn.Linear(3072, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2),\n", " nn.LogSoftmax(dim=1))\n", "\n", "learning_rate = 1e-2\n", "\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.NLLLoss()\n", "\n", "n_epochs = 100\n", "\n", "for epoch in range(n_epochs):\n", " for imgs, labels in train_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " loss = loss_fn(outputs, labels)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 0.732168\n", "Epoch: 1, Loss: 0.348352\n", "Epoch: 2, Loss: 0.318960\n", "Epoch: 3, Loss: 0.313264\n", "Epoch: 4, Loss: 0.378358\n", "Epoch: 5, Loss: 0.276529\n", "Epoch: 6, Loss: 0.443889\n", "Epoch: 7, Loss: 0.436946\n", "Epoch: 8, Loss: 0.324288\n", "Epoch: 9, Loss: 0.274647\n", "Epoch: 10, Loss: 0.291681\n", "Epoch: 11, Loss: 0.242894\n", "Epoch: 12, Loss: 0.301849\n", "Epoch: 13, Loss: 0.202063\n", "Epoch: 14, Loss: 0.389276\n", "Epoch: 15, Loss: 0.167129\n", "Epoch: 16, Loss: 0.135282\n", "Epoch: 17, Loss: 0.385485\n", "Epoch: 18, Loss: 0.453852\n", "Epoch: 19, Loss: 0.641304\n", "Epoch: 20, Loss: 0.287667\n", "Epoch: 21, Loss: 0.337029\n", "Epoch: 22, Loss: 0.393282\n", "Epoch: 23, Loss: 0.409480\n", "Epoch: 24, Loss: 0.138473\n", "Epoch: 25, Loss: 0.690729\n", "Epoch: 26, Loss: 0.572156\n", "Epoch: 27, Loss: 0.078534\n", "Epoch: 28, Loss: 0.324833\n", "Epoch: 29, Loss: 0.262829\n", "Epoch: 30, Loss: 0.430449\n", "Epoch: 31, Loss: 0.071872\n", "Epoch: 32, Loss: 0.058039\n", "Epoch: 33, Loss: 0.052903\n", "Epoch: 34, Loss: 0.065879\n", "Epoch: 35, Loss: 0.107696\n", "Epoch: 36, Loss: 0.305224\n", "Epoch: 37, Loss: 0.098637\n", "Epoch: 38, Loss: 0.139823\n", "Epoch: 39, Loss: 
0.226455\n", "Epoch: 40, Loss: 0.117763\n", "Epoch: 41, Loss: 0.106498\n", "Epoch: 42, Loss: 0.086254\n", "Epoch: 43, Loss: 0.135652\n", "Epoch: 44, Loss: 0.070890\n", "Epoch: 45, Loss: 0.304346\n", "Epoch: 46, Loss: 0.016917\n", "Epoch: 47, Loss: 0.057929\n", "Epoch: 48, Loss: 0.131021\n", "Epoch: 49, Loss: 0.136299\n", "Epoch: 50, Loss: 0.048885\n", "Epoch: 51, Loss: 0.241048\n", "Epoch: 52, Loss: 0.092595\n", "Epoch: 53, Loss: 0.059137\n", "Epoch: 54, Loss: 0.047421\n", "Epoch: 55, Loss: 0.102036\n", "Epoch: 56, Loss: 0.023338\n", "Epoch: 57, Loss: 0.054306\n", "Epoch: 58, Loss: 0.073878\n", "Epoch: 59, Loss: 0.031387\n", "Epoch: 60, Loss: 0.039865\n", "Epoch: 61, Loss: 0.022344\n", "Epoch: 62, Loss: 0.052310\n", "Epoch: 63, Loss: 0.059688\n", "Epoch: 64, Loss: 0.023977\n", "Epoch: 65, Loss: 0.010632\n", "Epoch: 66, Loss: 0.039090\n", "Epoch: 67, Loss: 0.080844\n", "Epoch: 68, Loss: 0.029650\n", "Epoch: 69, Loss: 0.027038\n", "Epoch: 70, Loss: 0.028515\n", "Epoch: 71, Loss: 0.021998\n", "Epoch: 72, Loss: 0.014992\n", "Epoch: 73, Loss: 0.019659\n", "Epoch: 74, Loss: 0.025150\n", "Epoch: 75, Loss: 0.017384\n", "Epoch: 76, Loss: 0.013249\n", "Epoch: 77, Loss: 0.009451\n", "Epoch: 78, Loss: 0.034637\n", "Epoch: 79, Loss: 0.114242\n", "Epoch: 80, Loss: 0.019007\n", "Epoch: 81, Loss: 0.016319\n", "Epoch: 82, Loss: 0.027428\n", "Epoch: 83, Loss: 0.022366\n", "Epoch: 84, Loss: 0.022583\n", "Epoch: 85, Loss: 0.006275\n", "Epoch: 86, Loss: 0.011964\n", "Epoch: 87, Loss: 0.018711\n", "Epoch: 88, Loss: 0.019636\n", "Epoch: 89, Loss: 0.018975\n", "Epoch: 90, Loss: 0.023520\n", "Epoch: 91, Loss: 0.016398\n", "Epoch: 92, Loss: 0.006638\n", "Epoch: 93, Loss: 0.013305\n", "Epoch: 94, Loss: 0.017126\n", "Epoch: 95, Loss: 0.021641\n", "Epoch: 96, Loss: 0.036945\n", "Epoch: 97, Loss: 0.004735\n", "Epoch: 98, Loss: 0.016781\n", "Epoch: 99, Loss: 0.012039\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "train_loader = 
torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
    "                                           shuffle=True)\n",
    "\n",
    "model = nn.Sequential(\n",
    "            nn.Linear(3072, 512),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(512, 2),\n",
    "            nn.LogSoftmax(dim=1))\n",
    "\n",
    "learning_rate = 1e-2\n",
    "\n",
    "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "loss_fn = nn.NLLLoss()\n",
    "\n",
    "n_epochs = 100\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    for imgs, labels in train_loader:\n",
    "        outputs = model(imgs.view(imgs.shape[0], -1))\n",
    "        loss = loss_fn(outputs, labels)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))"
   ] },
  { "cell_type": "code", "execution_count": 37, "metadata": {},
   "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.997700\n" ] } ],
   "source": [
    "def compute_accuracy(model, loader):\n",
    "    \"\"\"Return the fraction of samples in `loader` that `model` labels correctly.\n",
    "\n",
    "    Each batch of images is flattened to (batch_size, -1) before the forward\n",
    "    pass, as the fully connected models in this notebook expect; the predicted\n",
    "    class is the index of the largest output along dim 1. Runs under\n",
    "    torch.no_grad(), so no autograd graph is built.\n",
    "    \"\"\"\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for imgs, labels in loader:\n",
    "            outputs = model(imgs.view(imgs.shape[0], -1))\n",
    "            _, predicted = torch.max(outputs, dim=1)\n",
    "            total += labels.shape[0]\n",
    "            correct += int((predicted == labels).sum())\n",
    "    return correct / total\n",
    "\n",
    "\n",
    "# Shuffling does not affect accuracy, so evaluate on an unshuffled loader.\n",
    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
    "                                           shuffle=False)\n",
    "\n",
    "print(\"Accuracy: %f\" % compute_accuracy(model, train_loader))"
   ] },
  { "cell_type": "code", "execution_count": 38, "metadata": {},
   "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.821000\n" ] } ],
   "source": [
    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n",
    "                                         shuffle=False)\n",
    "\n",
    "print(\"Accuracy: %f\" % compute_accuracy(model, val_loader))"
   ] },
  { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [],
"source": [ "model = nn.Sequential(\n", " nn.Linear(3072, 1024),\n", " nn.Tanh(),\n", " nn.Linear(1024, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2),\n", " nn.LogSoftmax(dim=1))" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " nn.Linear(3072, 1024),\n", " nn.Tanh(),\n", " nn.Linear(1024, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2))\n", "\n", "loss_fn = nn.CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 0.641261\n", "Epoch: 1, Loss: 0.525149\n", "Epoch: 2, Loss: 0.466143\n", "Epoch: 3, Loss: 0.451913\n", "Epoch: 4, Loss: 0.343860\n", "Epoch: 5, Loss: 0.309738\n", "Epoch: 6, Loss: 0.485261\n", "Epoch: 7, Loss: 0.283789\n", "Epoch: 8, Loss: 0.301561\n", "Epoch: 9, Loss: 0.408200\n", "Epoch: 10, Loss: 0.346715\n", "Epoch: 11, Loss: 0.358134\n", "Epoch: 12, Loss: 0.388485\n", "Epoch: 13, Loss: 0.378096\n", "Epoch: 14, Loss: 0.518019\n", "Epoch: 15, Loss: 0.359279\n", "Epoch: 16, Loss: 0.420371\n", "Epoch: 17, Loss: 0.366249\n", "Epoch: 18, Loss: 0.282639\n", "Epoch: 19, Loss: 0.468854\n", "Epoch: 20, Loss: 0.467920\n", "Epoch: 21, Loss: 0.237441\n", "Epoch: 22, Loss: 0.243472\n", "Epoch: 23, Loss: 0.566929\n", "Epoch: 24, Loss: 0.316143\n", "Epoch: 25, Loss: 0.336322\n", "Epoch: 26, Loss: 0.473064\n", "Epoch: 27, Loss: 0.407040\n", "Epoch: 28, Loss: 0.252989\n", "Epoch: 29, Loss: 0.195740\n", "Epoch: 30, Loss: 0.663084\n", "Epoch: 31, Loss: 0.659899\n", "Epoch: 32, Loss: 0.285113\n", "Epoch: 33, Loss: 0.212042\n", "Epoch: 34, Loss: 0.324017\n", "Epoch: 35, Loss: 0.097063\n", "Epoch: 36, Loss: 0.181754\n", "Epoch: 37, Loss: 0.091362\n", "Epoch: 38, Loss: 0.069348\n", "Epoch: 39, Loss: 0.085656\n", "Epoch: 40, Loss: 0.163399\n", "Epoch: 41, Loss: 0.064912\n", "Epoch: 42, Loss: 
0.046740\n", "Epoch: 43, Loss: 0.029891\n", "Epoch: 44, Loss: 0.018157\n", "Epoch: 45, Loss: 0.103532\n", "Epoch: 46, Loss: 0.161911\n", "Epoch: 47, Loss: 0.238185\n", "Epoch: 48, Loss: 0.081116\n", "Epoch: 49, Loss: 0.040988\n", "Epoch: 50, Loss: 0.008668\n", "Epoch: 51, Loss: 0.012557\n", "Epoch: 52, Loss: 0.015967\n", "Epoch: 53, Loss: 0.020964\n", "Epoch: 54, Loss: 0.023478\n", "Epoch: 55, Loss: 0.012850\n", "Epoch: 56, Loss: 0.054703\n", "Epoch: 57, Loss: 0.014922\n", "Epoch: 58, Loss: 0.045488\n", "Epoch: 59, Loss: 0.122221\n", "Epoch: 60, Loss: 0.028012\n", "Epoch: 61, Loss: 0.029533\n", "Epoch: 62, Loss: 0.004758\n", "Epoch: 63, Loss: 0.080409\n", "Epoch: 64, Loss: 0.005409\n", "Epoch: 65, Loss: 0.020399\n", "Epoch: 66, Loss: 0.008184\n", "Epoch: 67, Loss: 0.013888\n", "Epoch: 68, Loss: 0.002199\n", "Epoch: 69, Loss: 0.001918\n", "Epoch: 70, Loss: 0.018765\n", "Epoch: 71, Loss: 0.004223\n", "Epoch: 72, Loss: 0.001795\n", "Epoch: 73, Loss: 0.102238\n", "Epoch: 74, Loss: 0.002482\n", "Epoch: 75, Loss: 0.005807\n", "Epoch: 76, Loss: 0.001742\n", "Epoch: 77, Loss: 0.012760\n", "Epoch: 78, Loss: 0.017469\n", "Epoch: 79, Loss: 0.002849\n", "Epoch: 80, Loss: 0.001452\n", "Epoch: 81, Loss: 0.002740\n", "Epoch: 82, Loss: 0.003317\n", "Epoch: 83, Loss: 0.002066\n", "Epoch: 84, Loss: 0.001952\n", "Epoch: 85, Loss: 0.010757\n", "Epoch: 86, Loss: 0.004866\n", "Epoch: 87, Loss: 0.003957\n", "Epoch: 88, Loss: 0.001295\n", "Epoch: 89, Loss: 0.004410\n", "Epoch: 90, Loss: 0.002952\n", "Epoch: 91, Loss: 0.000676\n", "Epoch: 92, Loss: 0.001835\n", "Epoch: 93, Loss: 0.000739\n", "Epoch: 94, Loss: 0.001102\n", "Epoch: 95, Loss: 0.000792\n", "Epoch: 96, Loss: 0.000515\n", "Epoch: 97, Loss: 0.001548\n", "Epoch: 98, Loss: 0.026913\n", "Epoch: 99, Loss: 0.000140\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n", " shuffle=True)\n", "\n", "model = 
nn.Sequential(\n", " nn.Linear(3072, 1024),\n", " nn.Tanh(),\n", " nn.Linear(1024, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2))\n", "\n", "learning_rate = 1e-2\n", "\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.CrossEntropyLoss()\n", "\n", "n_epochs = 100\n", "\n", "for epoch in range(n_epochs):\n", " for imgs, labels in train_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " loss = loss_fn(outputs, labels)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.999700\n" ] } ], "source": [ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n", " shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "with torch.no_grad():\n", " for imgs, labels in train_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " _, predicted = torch.max(outputs, dim=1)\n", " total += labels.shape[0]\n", " correct += int((predicted == labels).sum())\n", " \n", "print(\"Accuracy: %f\" % (correct / total))" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.801000\n" ] } ], "source": [ "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n", " shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "with torch.no_grad():\n", " for imgs, labels in val_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " _, predicted = torch.max(outputs, dim=1)\n", " total += labels.shape[0]\n", " correct += int((predicted == labels).sum())\n", " \n", "print(\"Accuracy: %f\" % (correct / total))" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"3737474" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum([p.numel() for p in model.parameters()])" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3737474" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum([p.numel() for p in model.parameters() if p.requires_grad == True])" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1574402" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "first_model = nn.Sequential(\n", " nn.Linear(3072, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 2),\n", " nn.LogSoftmax(dim=1))\n", "\n", "sum([p.numel() for p in first_model.parameters()])" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1573376" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum([p.numel() for p in nn.Linear(3072, 512).parameters()])" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3146752" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum([p.numel() for p in nn.Linear(3072, 1024).parameters()])" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1024, 3072]), torch.Size([1024]))" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "linear = nn.Linear(3072, 1024)\n", "\n", "linear.weight.shape, linear.bias.shape" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "conv = nn.Conv2d(3, 16, kernel_size=3)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([16, 3, 3, 3])" ] }, 
"execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "conv.weight.shape" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([16])" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "conv.bias.shape" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "img, _ = cifar2[0]\n", "\n", "output = conv(img.unsqueeze(0))" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 3, 32, 32]), torch.Size([1, 16, 30, 30]))" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img.unsqueeze(0).shape, output.shape" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGXFJREFUeJztnXt4VeWVxt8lAUEDRQhICtiAYos3Lg14G6yCqPg4g1i1OtWh1hHt6Iyd6Uyljq20tR3tU7XaUgWrI/YRFeu99YYpLVQZJCImYqxcREECJGKEiBhD1vxxDhriXisn+5yzT9Lv/T1PniTrzbf3OvucN/ucvfa3PlFVEELCY59CJ0AIKQw0PyGBQvMTEig0PyGBQvMTEig0PyGBQvMTEig0PyGBQvMTEihF2QwWkdMA3AKgG4DfqOr13t/3LhHtXxatvfWGM3Df6PA+Pe0h3aWbqfXaz37YBxSXmFofDIyMFzn/Qxux3dQ27lhjasW97TsvP2cqQHcjvssZYxxeAP7Zwbs3dLcR398Zkw8ajHiTM+ZD8ygC3qPevqPZ1Jo+dDa509Es3jfiuwFtUclkExL39l4R6QbgDQCTAWwEsBzA+ar6mjWmrFz0+5XR2j+f4uxsWHS4z0jbxAOKbIuMPqq/qU2bcLGpTZbLI+MDnZf0C3jO1L5bcYapHTvpI1OzRwEDjPhqZ4xxeAEAxY7m/UNpNOLjnTFxaXG0J4z4emdMDQabWjNsgz9XscXU3qpxdviyo1k8acTrAf04M/Nn87Z/PIA1qrpOVZsA3A9gahbbI4QkSDbmHwxgQ6vfN6ZjhJAuQDbmj3pr8ZnPECIyQ0QqRaRyR10WeyOE5JRszL8RwNBWvw8BsKntH6nqXFUtV9Xy3tYHUkJI4mRj/uUARojIMBHpAeA8AI/nJi1CSL6JXepT1WYRuQLAM0iV+u5S1VXemG5wrh5/1Rl4WXR4+0j7yuv2I981tTc329qaF+fY2sTow3X+2JPMMcc5hblfT+praqthXzk2CiYA7CvwQ5wx9lEE6h3Nzj7uVf3hjnaUqVRiuan9z2PvRMaLh0aGU5TYj7riVrsK02OMs007Rbse6WE90R0o3mVV51fVJ2EXHQghnRje4UdIoND8hAQKzU9IoND8hAQKzU9IoGR1tb+jbAHwS0ObdKk9rsKoD44a49VW7HLeK99729YeX2drU74TGa+eZZehJo+vMjWvwuNMWMRGR7MqSlOcMYMc7VhH64MDHdU6/l5h0S6jvYAzTW3hY3Y5ddmZ86KFs+wsjv6VnQeOtKWmF20Nbzqa5cJFzpgcwDM/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGS6NX+D94Bnr8mWjv4Onvc5V+Pjs9+yOl/5LVNGudo3rzEp6LDSy+3r+j/vbM572r/TY7mMdmIe9fYvTZefdz+LNE9DQHgMuPRTcah5phxTjfBOqehWPXQC0wNMK72OwfkS6W21jDB1v7qXdF3tlmo2TE88xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGSaKkPmwH8JFpa6wyb/Q+G4M1+cRrMHeysDrTWKxHOjw5vchrdfWOxsz0n/4Exl7axFhuzSoAAcKi7AJi97NkfPtus+RN2GU/AMNiTsZbC7oV4ntfkcawtmY+8bKE5YqG55hSwabazK2/G1QZHs5Y3yjM88xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGSValPRNYD2AFgN4BmVS2PvbF7Hc0qvzn91PBNW2p2lmo6+le2tsxaImm1k4eHM4Ow8WZb+5eDbM1aCHmJk0Y13je1Mkdb6WxznNHf7z38yRxzH86xNxi1JnRGnB0dbt7fHLFp9qP25ryZewc4WidcoToXdf6TVNVb0o0Q0gnh235CAiVb8yuAZ0XkJRGZkYuECCHJkO3b/uNVdZOIDASwUEReV9W9bmhN/1PgPwZCOhlZnflVdVP6+1YAjyBiWXZVnau
q5VldDCSE5JzY5heR/UWk956fAZwC4NVcJUYIyS+iqvEGigxH6mwPpD4+zFdVY87eJ2Pi7ezHRvweZ4y3zpQzm+6Lc2ztEiN+hLOreqcB5owX3zG1ndX2NkddbGvWBDFv1uTxjvYtR7NmEAJAKaLrkdXYbY65oMqpYY5yOmfiZ44WA++BTXQ0u8cosNTRrBl/MWf7qWpGhdHYn/lVdR2AUXHHE0IKC0t9hAQKzU9IoND8hAQKzU9IoND8hARK7FJfrJ3FLfVNMeK9nTHeTLv3HM1qFgoAJxhxZ+2/rzrVK6fHKO5c4YheM0ij8efhzlpxzlw6t4zpTI5EmRGvQX9zzIkVh9kbPNmaUgkAy23JKr+NdDZ3rqN5DV5rHc1Y5zEfZFrq45mfkECh+QkJFJqfkECh+QkJFJqfkEBJdrkuDy+TNUb8n2Lua4GjPexo1gpPZfaQhy53tucsybWPvaoVBjnLU40w4hc6aRziaB5eWzpLa8a79qCFh8ZL5G7nar/B1Om2NsgZN+dSR7RXAOuU8MxPSKDQ/IQECs1PSKDQ/IQECs1PSKDQ/IQESucp9TU72g4j7k2y8HB6+LlHZLIR9yYYbXa0eU4aV9rahO7ONg28JZWsh9UecQ6/Nz0HN9iTfuD0Qrx9+mmmdgiejox7x2O9o7kDvddwJ4RnfkICheYnJFBofkICheYnJFBofkICheYnJFDaLfWJyF0AzgCwVVWPSMf6AXgAqfls6wGcq6peZ7zssMpl850xzqw4c+ob4Pf+O8CIe7MLvRmEznJMTU6fvvXDbW2IEe+JA80xT2GLqXl9Bp9wNGsipo+3N3u9qzKnh5/1lHn5NcOeXTjqyjdM7ZUjnY3+0NGs16pXkh5gxP/sjGlDJmf+uwG0LaTOBFChqiMAVKR/J4R0Ido1v6ouBrCtTXgqPr1FZR6AM3OcFyEkz8T9zH+gqtYCQPr7wNylRAhJgrzf3isiMwDMyPd+CCEdI+6Zf4uIlAJA+vtW6w9Vda6qlqtqecx9EULyQFzzPw5gTxe06QAey006hJCkaHe5LhG5D8CJAEoAbAFwLYBHkSpiHQTgbQDnqGrbi4JR20pubTAPr0OjNwvPKuV4H2p6OZoznc5b5usc7O9sNHp9qhKn1FePKlP7P2dPv/jQEW804vc4Y1b/xtZG2p1Qv/baR6Y2wYh/CUeZY8bhFlNrxhxTK4L9pC3Hwaa22Tj+jbDLiq/r2sj4/HEbsaXyo4yW62r3M7+qnm9IkzLZASGkc8I7/AgJFJqfkECh+QkJFJqfkECh+QkJlM7TwLOz4M2kqjbi33PG/NyWpjvlPGs2GgCsh93ossQoKRU5T/VKZ1+/8O7g+KOjWY0uvVmTWGdLk+0npgF2qc+q3BY75c0ifNfUSvG2qR1q1jeBSfi6qVmtULfhHXNEPzk5Mr4Emd9LxzM/IYFC8xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKF271Odl762b1uBo7mJyBk4jTr+0Zdf6ijDN2Z3dKXKIMVutHu+aY5asMCVg6UJby/m6dT81laMP2NfU/t3ZopWi18BzidMQ1Ht5/AgXmNpwZ5vA5yOjy/GBOeJUnORsLzN45ickUGh+QgKF5ickUGh+QgKF5ickULr21f5YV5QR74p+XKJb6qWk9+yr2x/uOs/UDinpZm/UeEaLnIrEJWPbLsj0KReNtcetsZs2o/qZ6Ek6f1hwg71BPGoqE5rtyTunftJL9rPc+MnaMnsTZyUsAKh1tDcdbYjTF9B6arz+iZUf3h0Zr23xmlDuDc/8hAQKzU9IoND8hAQKzU9IoND8hAQKzU9IoLRb6hORuwCcAWCrqh6Rjs0CcAmAuvSfXa2qT+YryU6PU87r1/gDU3twtr2E04C+djmvYYS9v0aj0rNmtV0qKxthT5rp2dfe14SJ9srsg46L1p4660pzTMvDdqlvqVNHe80o5wHAaCM+DIPNMRuc3nm9Hcs0w37O/tfpMzjEiE8xRwA
9e0VP4Jq/z/vOqL3J5Mx/N4CoQvDNqjo6/RWu8QnporRrflVdDKDdRTgJIV2LbD7zXyEiVSJyl4h4naYJIZ2QuOa/DcDBSH2kqoW9IDNEZIaIVIpIZcx9EULyQCzzq+oWVd2tqi0A7gAw3vnbuaparqqZryZACMk7scwvIqWtfp0G4NXcpEMISYpMSn33ATgRQImIbARwLYATRWQ0AAWwHsCleczR5PNldolq2AnmmxEU7bIf9p8XLOp4IsP+w5S2vTnBHlf3liltHbG/qdVutstU26rfiBaqVpljVjXaveLQaJeOHho3xtR6jIkuY7Y87PQEdHjeWioNwK+dcb2NeJ1TzhvpbG+yM5W0r6N5bSOtjozj4c2AvCwy2gtfccbsTbvmV9XzI8J3ZrwHQkinhHf4ERIoND8hgULzExIoND8hgULzExIoiTbwHNz/8/jXqdElip4n2GWjnmMOi4yfNGy4OabYqvHAnYSHswZdYWoVt94fLVjlNQCofttJxC7nod5eQ2tb3YHOuOjGmXBKW0B/R3PW8lpiz1hsWmJt83POvhycUp/Xj/VpI772OmeQ16XTaWh66cW29hdnk1b+xzkNTe1EnLJtG3jmJyRQaH5CAoXmJyRQaH5CAoXmJyRQaH5CAiXRUt+gslJcdef3k9xlh6mpdxa1w7tG/PfxdubtqsYrv51tS32PjY43OGVFOOVIZz0+H+tYWfH4eGvamS9w75XvTRN0pvzNcZqdmlP3AKwaFh1/ovtSc8xPMDkyvt1JoS088xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKIle7e8K1Fd7kymSxLsqPseWGqw+cnZ/OcCYsNSZcF6pqx5zxp0QHf7yTHvISxuc7RnLoQEAvHGnd3zcSxvtIbcZj6sjtRme+QkJFJqfkECh+QkJFJqfkECh+QkJFJqfkEDJZLmuoQDuATAIQAuAuap6i4j0A/AAgDKkluw6V1Xfy1+qHaMJm0yth1NGK6q2l6dqyiqjpPgbXUxphqMNdTSjwlntvFK/6PT367vD1mqcMmDPXra21eg3ebjd1hK7PoyOa4s9pi2ZnPmbAXxHVUcCOAbA5SJyGICZACpUdQSAivTvhJAuQrvmV9VaVV2R/nkHgBoAgwFMBTAv/WfzAJyZryQJIbmnQ5/5RaQMwBgAywAcqKq1QOofBICBuU6OEJI/Mja/iBQDeAjAt1U1454BIjJDRCpFpLKuri5OjoSQPJCR+UWkO1LGv1dVH06Ht4hIaVovhXFbsarOVdVyVS0fMGBALnImhOSAds0vIoLUJeQaVb2plfQ4gOnpn6cD8KZXEEI6GZnM6jsewIUAqkVkZTp2NYDrASwQkYuRagJ3Tn5SBLYZ8UZYS1MBDfqcqQ3CFlPbmWlSJFGOnm1ry56xtT7GqlbeC7/WaWl40UFHmdq0g6pMzZtTeY0x7JhJ9hirJeBrHbiK1675VfUvAMSQnfQIIZ0Z3uFHSKDQ/IQECs1PSKDQ/IQECs1PSKB0iQae/Yx4MYabYzb/8R1Te6p+iantV2znsdNbXotkz5SY4162pQNOjY5700+nHWRr52BfU+vpbHORox0/MTruTVa874Xo+LYOvEZ55ickUGh+QgKF5ickUGh+QgKF5ickUGh+QgKlS5T64lAybLCplU20OyOOqbbLgM//JHpu1pevsvN4yZb82tBqR5vvbTRBjnW0pTG2d40tTcbnTG30TPtlvMZo1rpc7X3tsqaxAbgJy03Nq1Q6y+5hgrG/OifHDW9Gx5s60GWWZ35CAoXmJyRQaH5CAoXmJyRQaH5CAiXRq/0tsHvkNRrLDwFAX2OpoyJ8YI4ZPtye9NO4Y7GpWVf0PWrmOOLpjuZ1Mh/R4TSSpyHGmCGO5iyFdd1Eexk1jHS2aVQQ9nEmcD1gXEkHADgTZ54+ztZOczY5wYg3OFWHhrOi40/d6OyoDTzzExIoND8hgULzExIoND8hgULzExIoND8hgdJuqU9EhgK4B8A
gpKp1c1X1FhGZBeASfFqwulpVn/S2tQ+A/Qyt3ikb9TBKfVvxe3PMgw+cZ2pX2JL737DFiO/0Sl5xJ+EsjDkuSTpeFQWM5xIA8A1H2+xoXoO88dHhFq+Jnzcp6VxbWvtzW5ttzxfDGGOVy4tgT06r7hXdo7JbLpfrQuop/o6qrhCR3gBeEpE9L82bVdV5yISQzkoma/XVAqhN/7xDRGoA518SIaRL0KHP/CJSBmAMgGXp0BUiUiUid4nIATnOjRCSRzI2v4gUA3gIwLdVdTuA2wAcDGA0Uu8MIm8sFJEZIlIpIpV1dd79rISQJMnI/CLSHSnj36uqDwOAqm5R1d2q2gLgDhiXVlR1rqqWq2r5gAEDcpU3ISRL2jW/iAiAOwHUqOpNreKlrf5sGoBXc58eISRfZHK1/3gAFwKoFpGV6djVAM4XkdEAFMB6AJe2t6FG7MILqInUajesNcdVvxYd/+0iu2b3wLPtZRONVc7rVPybo92a431da0s9jrS1prMNwetNGBen/GbOFKx3xixwNK+YHXM5t18aM1qPNMp5AHBHVXT8Y2d2bFsyudr/FwBRkwvdmj4hpHPDO/wICRSan5BAofkJCRSan5BAofkJCZREG3hu/7geC2vvjtSqX7jXHNe4OrrksehlZ2deE8YuzqSv2VpFrkt982ypyZvNOM6I26tdxWeYow014t1j7itmOc9ryPqK8Tp+xGkIWmI8rroemafEMz8hgULzExIoND8hgULzExIoND8hgULzExIoiZb6PvpgG1a/GF3SK+ptz2AqMZowTnDWaKv4L1vrY0vY7mgWp06xtWeeirFBAJOsUhmAMWNsrcKa8Re3BLje0fo6mlXa8pp+eqVbjziNRE9ytH90tLgNWb3ZjMZaj9c7axeOOjU6/l63jDPimZ+QUKH5CQkUmp+QQKH5CQkUmp+QQKH5CQmUZEt973+MNU9Gl/SKrdlXADYaWQ5yymFTH7W1eqd5Y4OTx67F0fGlccs/DhXO7LeKmc5Aozv6frfbQ3bOcrbnNMc8/Ou2dohRnu3p7OoRY806AGjyOkYOcTTrud7ljHFKyHnBKnE6B6vamMnY4j2uNvDMT0ig0PyEBArNT0ig0PyEBArNT0igtHu1X0R6AlgMYN/03/9OVa8VkWEA7gfQD8AKABeqapO3rf49gYuMCR+vO1fZrfkjzc6i4IPG2trmFbZW41y5b4lch7gAeH3k7okO73T67X35x7a2wVlYedUNjnZKdHw/Z1LST6faWo2jPae29pa19FatPQaljuZUmGL393uv40OKjErAxx04nWfypx8BmKiqo5Bajvs0ETkGwA0AblbVEUilf3HmuyWEFJp2za8p9vxP657+UgATAfwuHZ8H4My8ZEgIyQsZvUkQkW7pFXq3AlgIYC2ABlXdM5N6I4DB+UmREJIPMjK/qu5W1dFI3Us1HtH3QEV+8hKRGSJSKSKVjXE/ExFCck6HrvaragOAPwE4BkBfEdlzwXAIgE3GmLmqWq6q5cXF2aRKCMkl7ZpfRAaISN/0z70AnAygBsAiAGen/2w6AOfObEJIZ0NUnToJABE5CqkLet2Q+mexQFV/JCLD8Wmp72UAF6jqR962RpeKPmvUBDZeta857r4bozc735ns0eBMzujrlHkaFtraTlvKPSWO5pWbYvYMNJngaM4kkj5GqW+7s4zaF75pa2dMsjVjLhMA4NZ10fFttziDnDIxnNKnu0Scl+TDRtzrTWhN1JoB6OsqzshPaLfOr6pVAD5TnVXVdUh9/ieEdEF4hx8hgULzExIoND8hgULzExIoND8hgdJuqS+nOxOpA/BW+tcS2B3WkoR57A3z2JuulscXVNUrLH5Coubfa8cilapaXpCdMw/mwTz4tp+QUKH5CQmUQpp/bgH33RrmsTfMY2/+ZvMo2Gd+Qkhh4dt+QgKlIOYXkdNE5K8iskZEvMWn8p3HehGpFpGVIlKZ4H7vEpGtIvJqq1g/EVkoIqv
T3532pHnNY5aIvJM+JitF5PQE8hgqIotEpEZEVonIlel4osfEySPRYyIiPUXkRRF5JZ3HD9PxYSKyLH08HhCRHlntSFUT/UJqavBaAMMB9ADwCoDDks4jnct6ACUF2O8JSE0cfbVV7GcAZqZ/ngnghgLlMQvAfyZ8PEoBjE3/3BvAGwAOS/qYOHkkekwACIDi9M/dASxDqoHOAgDnpeO3A/hWNvspxJl/PIA1qrpOU62+7wfgNGb+20NVFwPY1iY8Fam+CUBCDVGNPBJHVWtVdUX65x1INYsZjISPiZNHomiKvDfNLYT5BwPY0Or3Qjb/VADPishLIjKjQDns4UBVrQVSL0IAAwuYyxUiUpX+WJD3jx+tEZEypPpHLEMBj0mbPICEj0kSTXMLYf6oLiOFKjkcr6pjAUwBcLmInFCgPDoTtwE4GKk1GmoBJLZUiYgUA3gIwLdVdXtS+80gj8SPiWbRNDdTCmH+jQBar89jNv/MN6q6Kf19K4BHUNjORFtEpBQA0t+3FiIJVd2SfuG1ALgDCR0TEemOlOHuVdU9ja0SPyZReRTqmKT33eGmuZlSCPMvBzAifeWyB4DzADyedBIisr+I9N7zM4BTALzqj8orjyPVCBUoYEPUPWZLMw0JHBMREQB3AqhR1ZtaSYkeEyuPpI9JYk1zk7qC2eZq5ulIXUldC+C/C5TDcKQqDa8AWJVkHgDuQ+rt48dIvRO6GEB/ABUAVqe/9ytQHr8FUA2gCinzlSaQx98h9Ra2CsDK9NfpSR8TJ49EjwmAo5BqiluF1D+aH7R6zb4IYA2ABwHsm81+eIcfIYHCO/wICRSan5BAofkJCRSan5BAofkJCRSan5BAofkJCRSan5BA+X+oAC6reFaYfAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(img.permute(1, 2, 0), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAF4lJREFUeJztnVuMndV5hp/P40NssMHHsfEB24QQCEkxmqAqVBVVDqJVJBJFieKLiEpRnIsgNVIuGnETbiqhKofmAkVyCgqRckIKabhAbSJUCRIpFoYQsLExjuMTHp9iO8YGYmN/vZjtdurMetfMnvHeg9b7SMgz+5v/X2uv/b/8e+93fd8XmYkxpj1m9HsCxpj+YPEb0ygWvzGNYvEb0ygWvzGNYvEb0ygWvzGNYvEb0ygWvzGNMnMyB0fE3cC3gQHg3zPzQfX3c+fOzQULFnQ11ltvvVWMzZ49uxibNWtWMXbx4kU55rlz54qxP//5z8XYhQsXirGBgQE5Zm1OJdROzdo5VVzFartDZ8zo7t4SEcWYej0B5s2bV4yp60RdXzVmzuxORmp9atdJaY1OnTrF2bNnyws4iq7FHxEDwEPAR4GDwLMR8URmvlw6ZsGCBWzcuHHMWO1CevXVV4uxlStXFmPLly8vxt5880055v79+4uxPXv2FGNnzpwpxubPny/HfOONN4oxJQr1P6Pa8zx79mxXsfPnz8vzvutd7yrG1HNRIh0cHJRjbtiwoRhbs2ZNMbZ79255XsWiRYuKMXVdq2uhdp3MmTNnzMcfeughedxoJvO2/w5gd2buycxzwI+BeyZxPmNMD5mM+FcCB0b9frDzmDHmHcBkxD/W+7a/eI8TEZsiYmtEbK29/TTG9I7JiP8gsHrU76uAQ5f/UWZuzsyhzByaO3fuJIYzxkwlkxH/s8CNEbEuImYDnwWemJppGWOuNF1/25+Zb0fEfcB/MWL1PZKZ29UxZ8+eZcuWLWPG7rzzTjnejTfeWIypb44XLlxYjK1du1aOqb7p3rdvXzGmvu2v2ULq3ZE6Vh33+uuvyzFPnTpVjCnL6brrrpPnVbabcifUt/1XX321HFM5DMr52bVrVzF29OhROaY6r1q/q666qqsYlF/vmkU4mkn5/Jn5JPDkZM5hjOkP3uFnTKNY/MY0isVvTKNY/MY0isVvTKNY/MY0yqSsvokyMDBQ9N2V1ww6c0pl3/3xj38sxj70oQ/JMa+99tpiTPmw69atK8ZqKanKp1VrpLz8WvZdKUMMdHbZsmXL5HnVvgSVRafSvms+v3pd1PPsdn9A7Vg1XxVTc4Xy2qo9L5fjO78xjWLxG9MoFr8xjWLxG9MoFr8xjWLxG9MoPbX65syZww033DBmrJY2qSriqnTWgwcPFmMqLRe0laXsH1Xx9k9/+pMc8/jx48WYsn9UlSSV1gw69VbZdbWqwGq+qhDnBz/4wWLs9ttvl2OqNGxlsypbt5Ymq6rwqtRltbYTSc0dTa0Q7mh85zemUSx+YxrF4jemUSx+YxrF4jemUSx+Yxqlp1bfjBkzirbS4sWL5bHdNmBUFtivf/1rOeaSJUuKsQMHDhRjKgNx/fr1ckzVq081Dl29enUxVsuYHB4eLsaUdTaZpqPXX399MfaRj3ykGPvUpz4lx3z++eeLsdOnTxdjK1asKMZOnjwpx7wS1Cy7bhu6jsZ3fmMaxeI3plEsfmMaxeI
3plEsfmMaxeI3plF6avXNnDmzaOmp7CfQVqAqXrlq1apiTFk/oItBKstOWUO1Ap633HJLMfbqq68WY8oifOutt+SYEyn6OJq3335bxtXrosa85pprirFaAU9l+6pCryqrVBWBBW0Jq8xQdVzN6lNZruNlUuKPiL3A68AF4O3MHJr0jIwxPWEq7vx/l5nlJHRjzLTEn/mNaZTJij+BX0TEcxGxaaw/iIhNEbE1IraqraLGmN4y2bf9d2bmoYhYBvwyInZm5tOj/yAzNwObAVatWjX+GkPGmCvKpO78mXmo8+9R4GfAHVMxKWPMladr8UfEVREx/9LPwMeAbVM1MWPMlWUyb/sHgZ91/NqZwA8z8z/VARFRrFha8/mV16r2ACivXlVdBe2dqyq8yld/5ZVX5JjK+1Xe7pEjR7o6DnSjyclUmFXjTqTK7Gh27twp4yqlVzUHVd9HqWsI9PWnmo6q/Qy1vQWlPRYTWdeuxZ+Ze4C/6vZ4Y0x/sdVnTKNY/MY0isVvTKNY/MY0isVvTKP0NKUXylZErXqvSg9V9oZKHa1VQFXHqtTSlStXFmOqUi7otF1lySmLUNmAoKv7KjuvZvWp10VVY1bnffHFF+WYjz/+eDGmGrqeOXOmGFu3bp0cUzVCVdeQGrNm2ZWqUrtRpzGmisVvTKNY/MY0isVvTKNY/MY0isVvTKNMm0adtcyp1157rRhTlonKqpo/f74cU1lgqgrv0qVLizFVXRZg3759xZjKklOWksrMA22tKXtMZaXVWLNmTTGmmmY+88wz8rzbtpWzyt/97ncXY2r9VGVf0E1blT07Z86cYqxmQ5esvolU9fWd35hGsfiNaRSL35hGsfiNaRSL35hGsfiNaZSeWn0RUbS6ak0fFefOnSvGzp8/X4zVMglVEcWS1QLaplFFQUE3D1XPpdsinKAtT5W9qJqVgrZvVRaiei41K0tZsMeOHevqvHv27JFjqqxJlfGn1kDZ11C2Z2tNWUfjO78xjWLxG9MoFr8xjWLxG9MoFr8xjWLxG9MoFr8xjVL1+SPiEeDjwNHMvLXz2CLgJ8BaYC/wmcw8WTtXZha96lqjTuWJnjhxohg7ebI8rWXLlskxVWrprl27irFum3iC9sbVsSr1tlZld/ny5cWY2gtx7bXXyvMuWrSoGBscHCzGDh06VIypfRC1MWtpsiVq16ZK+VXPZfXq1cWYSjGG8r4YteflcsZz5/8ecPdlj30VeCozbwSe6vxujHkHURV/Zj4NXH5rvQd4tPPzo8AnpnhexpgrTLef+Qczcxig82/x/XNEbIqIrRGxtbZl0RjTO674F36ZuTkzhzJzSO0TN8b0lm7FfyQiVgB0/tVFzowx045uxf8EcG/n53uBn0/NdIwxvWI8Vt+PgLuAJRFxEPga8CDwWER8HtgPfHo8gymrT6XPgk7zVFaWSmetfQxRzSTVfFWKp7IBxzOnEsrGqlVGVlVkVaxmW86YUb633HzzzcWYasq6Y8cOOaZ6zdS1oFKBa+t39uzZYuzw4cPF2MGDB4sxVQFaMZFGnVXxZ+bGQujD4x7FGDPt8A4/YxrF4jemUSx+YxrF4jemUSx+Yxql5406S5bdZHb/KStGWWA1q0pZQ2pMZTfVstIU6ryq+rGq+gu60aSyJieT1acy5ZRFuHPnTjmmWofafEvUqh+ra1ddJyoDr5aBWMo4rVUaHo3v/MY0isVvTKNY/MY0isVvTKNY/MY0isVvTKP03OorNYWsFZk8depUV2Mq66fWHFTZUSqmil6qQpugbS5l/6j1qxV17NZyqr1mqimpymhTBUVrlrAqrKpQmXkq4w9gwYIFxZjKilTrV2voWmoUO5Eipb7zG9MoFr8xjWLxG9MoFr8xjWLxG9MoFr8xjWLxG9MoPfX5z507V0wfVT4raF+428qr11xzjRxTVUJV3vhkmluq/QNqPt16zaCfi4rVKhErH1vtsVB7HWqv2YULF4oxtUdArVFtP0i366d
ez+HhYTlmqXp0yf8fC9/5jWkUi9+YRrH4jWkUi9+YRrH4jWkUi9+YRhlPo85HgI8DRzPz1s5jDwBfAI51/uz+zHyydq4333yT3/3ud2PGlEUDcOzYMRkvsWzZsmKsZrutXr26GFOVaVVVYGVjgbaGlHVWSpWuHQc6DVSdVzXUrMXVnJTdOXv2bDnmkiVLirGVK1cWY8oGrKXXHj1a7lAfEV2NWau4fPLkyTEfn+qU3u8Bd4/x+Lcy87bOf1XhG2OmF1XxZ+bTwIkezMUY00Mm85n/voh4MSIeiYiFUzYjY0xP6Fb83wFuAG4DhoFvlP4wIjZFxNaI2FrrkGOM6R1diT8zj2Tmhcy8CHwXuEP87ebMHMrMoVKrLmNM7+lK/BExulHYJ4FtUzMdY0yvGI/V9yPgLmBJRBwEvgbcFRG3AQnsBb44nsHOnTvHvn37xozVqtqqbCVlDanzKlsIYP369cWYqiasmnHWmj6q56nGVNZQzepTazQ4OFiMKesM4Pjx48VYLdOwhLJYgWIjWNCW3U033dTVfEBn56nXUz2XWsXgUpbroUOH5HGjqYo/MzeO8fDD4x7BGDMt8Q4/YxrF4jemUSx+YxrF4jemUSx+YxrF4jemUXpavffixYvFdNdahVQVV16q2lKsusGC9mHPnDlTjKlU4ZpPXarKCvVuuyVUajJob3zevHnFWK1jrjpWVWtW+yRq69dtCrJKy61dJ6qisFoDlX5b24NS2i+i9lZcju/8xjSKxW9Mo1j8xjSKxW9Mo1j8xjSKxW9Mo/TU6lPUrD6FquiqqgKrJp6gbRqVmqsqtr73ve+VYx45cqQYU804VdNM9Tygbp+VqBVnUVbWiRPlspCHDx8uxpQNCNrqU9fYnj17irFaRVz1eivLTtnFKk0dys+zlr49Gt/5jWkUi9+YRrH4jWkUi9+YRrH4jWkUi9+YRump1TdjxoyiPVRrYKnsFHVst5lloG1CZamoyrQqAwx0FV4VU+tTa4Kq1k/ZoTVbqdvMve3btxdjtaw1ZYeq9VNz3b9/vxxToWxotX41e7FklU7EMved35hGsfiNaRSL35hGsfiNaRSL35hGsfiNaZTxNOpcDXwfWA5cBDZn5rcjYhHwE2AtI806P5OZJ9W5MrNot9RsI5Wtpaw1Ze/89re/lWOqOalimsoe27VrlxxTzVdlgSlrqGYvquw8VdhSFcSszUlZa6qIac3KUsVTlWWnriG17gC7d+8uxlRWn4otW7ZMjqmyOMfLeO78bwNfycybgb8GvhQRtwBfBZ7KzBuBpzq/G2PeIVTFn5nDmfl85+fXgR3ASuAe4NHOnz0KfOJKTdIYM/VM6DN/RKwFNgBbgMHMHIaR/0EA+n2KMWZaMe7tvRFxNfBT4MuZeVptJ73suE3AJtBbHY0xvWVcd/6ImMWI8H+QmY93Hj4SESs68RXAmN8MZebmzBzKzKFZs2ZNxZyNMVNAVfwxcot/GNiRmd8cFXoCuLfz873Az6d+esaYK8V43vbfCXwOeCkiXug8dj/wIPBYRHwe2A98+spM0RhzJaiKPzN/BZQ+4H94ogOWviuopTCqtFTVjPONN94oxk6elNsSePnll4sx5cer+fzmN7+RY950003F2PXXX1+MKf+7trZqP4Oqanvo0CF5XtUgVFWnVT7/H/7wBzmm+l5JNXRV1ZjVHhPQDUvV81Re/YoVK+SYCxcuHPNxV+81xlSx+I1pFIvfmEax+I1pFIvfmEax+I1plJ5X7y3ZIsqGATh16lQxpuwdZTd94AMfkGOqCrN79+4txpQlp2xA0Bbi3LlzizGVYlxbW4WyWGtNM1UqrEoHXrVqVTG2dOlSOaZKX1bPRV1DmSnHXL9+fTGm0n3V66KalSpqtu5ofOc3plEsfmMaxeI3plEsfmMaxeI3plEsfmMapadW32RYvHhxMaasIWWP1aoR3XrrrcWYahipmkXWCpps27atGFu+fHkxpiyeWgUltQ7KrqtVkFU
Zbao56NDQUDGmKg0DbNmypasxVdXfmj2rrF113tdee60YqzUHLZ231pR1NL7zG9MoFr8xjWLxG9MoFr8xjWLxG9MoFr8xjdJTq29gYKBYDFE1twS47rrrijFlAyqb5sCBA3LM973vfcXY2rVrizFlca1Zs0aOuX379mJMFa9U86k11FQZiqqZqXqeoF9TVaRTFV1V1hno7EZl+6oiner6Am0hbtiwoRjbuXNnMfbYY4/JMRctWiTj48F3fmMaxeI3plEsfmMaxeI3plEsfmMaxeI3plEsfmMaperzR8Rq4PvAcuAisDkzvx0RDwBfAI51/vT+zHxSDjZzJkuWLBkzVqswq/z6s2fPFmPr1q0rxmpNH1VTTZUmqzzsWhqxarj57LPPFmPvec97ijHVhBK0r668epW6DNr/Vr66qtRcq6Sr9gGotVep1rfccoscUzV8Vdf14OBgMabSt6Gcaj2R6r3j2eTzNvCVzHw+IuYDz0XELzuxb2Xm18c9mjFm2jCeFt3DwHDn59cjYgew8kpPzBhzZZnQZ/6IWAtsAC6VS7kvIl6MiEciYsyG4RGxKSK2RsTWWkUUY0zvGLf4I+Jq4KfAlzPzNPAd4AbgNkbeGXxjrOMyc3NmDmXmUK0EkzGmd4xL/BExixHh/yAzHwfIzCOZeSEzLwLfBe64ctM0xkw1VfHHyFekDwM7MvObox5fMerPPgmUK08aY6Yd4/m2/07gc8BLEfFC57H7gY0RcRuQwF7gi7UTZWYxDbRWCbbU4BN0Gqc6r7K4AJ5++uliTDUAVRZYrSqrsvpOnDhRjKlKsGquoFNzVUytO+hKuso+W7hwzK+PAG2rARw9erQYUzbYkSNHirH3v//9ckxVPVqlSytuv/12GS/ZvjUreTTj+bb/V8BYZ5SevjFmeuMdfsY0isVvTKNY/MY0isVvTKNY/MY0Sk+r92Ym58+f7+pY1TBSWU67d+8uxmoNLAcGBoox9TxqFXoVqtKuep7PPfdcMaay5EA/T7V+qhouaNvt2LFjxZhqfFmzstTrojIC1drWLGFVOVlZzWoNatWsV6xYMebjyvK9HN/5jWkUi9+YRrH4jWkUi9+YRrH4jWkUi9+YRump1XfhwoVisU2VzQY6m2vOnDnF2Lx584ox1fwTdBFFZf8ou6lk0VxCFdtUWWmqStLvf/97OaZaB2VjKYsQum/UqZ5LrVGnWj81XzXmtm06W11ZxocPHy7GVNHQmg1dKv5ZK9Y6Gt/5jWkUi9+YRrH4jWkUi9+YRrH4jWkUi9+YRrH4jWmUnvr8ilpjwlKDT9DNEJcuXdpVDPT+AeULK29XNRWFkb0QJVRjR1Xt9fjx43LM4eHhYkytQa2SrkLtv9i3b18xpvYHACxevLgYU6nWp0+fLsZU1WTQFXpVqrBKMa6l9JYqRNf2XozGd35jGsXiN6ZRLH5jGsXiN6ZRLH5jGsXiN6ZRQtkNUz5YxDFgtI+zBNA+VG/xfDTTbT4w/ebU7/lcn5naw+7QU/H/xeARWzNzqG8TuAzPRzPd5gPTb07TbT4Kv+03plEsfmMapd/i39zn8S/H89FMt/nA9JvTdJtPkb5+5jfG9I9+3/mNMX2iL+KPiLsj4pWI2B0RX+3HHC6bz96IeCkiXoiIrX2awyMRcTQito16bFFE/DIiXu38u7DP83kgIl7rrNMLEfEPPZzP6oj474jYERHbI+KfOo/3ZY3EfPq2RhOl52/7I2IA2AV8FDgIPAtszMyXezqR/z+nvcBQZvbNn42IvwXOAN/PzFs7j/0rcCIzH+z8T3JhZv5zH+fzAHAmM7/eizlcNp8VwIrMfD4i5gPPAZ8A/pE+rJGYz2fo0xpNlH7c+e8Admfmnsw8B/wYuKcP85hWZObTwOWJ4/cAj3Z+fpSRi6uf8+kbmTmcmc93fn4d2AGspE9rJObzjqEf4l8JHBj1+0H6v2gJ/CIinouITX2ey2gGM3MYRi42YFmf5wNwX0S82PlY0LO
PIaOJiLXABmAL02CNLpsPTIM1Gg/9EH+M8Vi/LYc7M/N24O+BL3Xe8pq/5DvADcBtwDDwjV5PICKuBn4KfDkzy+V3+jefvq/ReOmH+A8Cq0f9vgo41Id5/C+Zeajz71HgZ4x8NJkOHOl8trz0GfNoPyeTmUcy80JmXgS+S4/XKSJmMSK0H2Tm452H+7ZGY82n32s0Efoh/meBGyNiXUTMBj4LPNGHeQAQEVd1vrAhIq4CPgbo5my94wng3s7P9wI/7+NcLonrEp+kh+sUEQE8DOzIzG+OCvVljUrz6ecaTZS+bPLp2B//BgwAj2Tmv/R8Ev83l/WM3O1hpKDpD/sxn4j4EXAXI1lhR4CvAf8BPAasAfYDn87MnnwJV5jPXYy8nU1gL/DFS5+3ezCfvwGeAV4CLnUsvZ+Rz9k9XyMxn430aY0minf4GdMo3uFnTKNY/MY0isVvTKNY/MY0isVvTKNY/MY0isVvTKNY/MY0yv8AZfmG8htYlK4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 16, 30, 30])" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.shape" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 1, 32, 32])" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = conv(img.unsqueeze(0))\n", "\n", "output.shape" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " conv.bias.zero_()" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " conv.weight.fill_(1.0 / 9.0)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAF79JREFUeJztnV+MXdV1xn8L4z/YGBvb2Bjj1BD5IVFUSDRCkaiiNGkjGlUikZooeYh4QHFUgdRI6QOiUkOlPiRVkygPVSqnoJAqhND8UVAVtUEoFcoLiUMJfwIFgt1gPLGNwdgQAni8+nCv1WFy1zd3zsyca7K/nzSaO2fdfc4++5xv7r37u2vtyEyMMe1xzqQ7YIyZDBa/MY1i8RvTKBa/MY1i8RvTKBa/MY1i8RvTKBa/MY1i8RvTKOcupnFEXAN8GVgB/Etmfk49f926dblx48aRMfVNw5mZmZHbzzmn/t+1YsWKMta13bnnjh4utT91XqdPn+4U65OIWNJY1/11Hcfq3lH76xrrej2rmGpTjdWJEyd45ZVX6oGcRWfxR8QK4J+APwUOAj+NiLsz8xdVm40bN3LDDTeMjP32t78tj/XSSy+N3L5mzZqyzQUXXFDG1q1bV8Y2bNhQxjZt2rTg/b3++utlrDovgN/85jdlbKlR//BWrlxZxtQ/vVWrVo3cvnr16k7HOnXqVBk7ceJEGavG+NVXXy3bVP8wAF577bUypq6ZilX3/iuvvFK2qV6I7rjjjrLNXBbztv8q4KnMfDozXwPuBK5dxP6MMT2yGPHvAJ6Z9ffB4TZjzJuAxYh/1OeK3/lAFBF7ImJfROx7+eWXF3E4Y8xSshjxHwR2zvr7UuDQ3Cdl5t7MnMrMKfXZ2BjTL4sR/0+B3RFxWUSsAj4G3L003TLGLDedZ/sz81RE3Aj8JwOr77bMfFS1iYhyllLN9Faz+uedd96C24C2a9TsfDW7ff7553c6VjUWoMdDWUDV8dSMvhpH9W6tmtGHelZfzfar8VBukJqdr2b71RiqfnS189R9VTkBysWoxn4hFvGifP7M/AHwg8XswxgzGfwNP2MaxeI3plEsfmMaxeI3plEsfmMaZVGz/V2obBmV0aWsqIquSTPKmjt27NjI7ZdddlnZpkoGAp0kor4NqWyvytJTdqSyttTYq3aVLaoSdNQ90DWx5+jRoyO3q6QZlRSm7g+VLKRi1f2o7tMumYBz8Su/MY1i8RvTKBa/MY1i8RvTKBa/MY3S62z/zMwMJ0+eHBlTSSLVTHXXBB2VCKJm2asZW1XOqjpf0P1Xs9GKtWvXjtyuEnuqNqATpNQMfOVkdF0VWs1iq5n0qh+qTdd7p8uMPtTOlHKsKmdkIePrV35jGsXiN6ZRLH5jGsXiN6ZRLH5jGsXiN6ZRerf6XnjhhZExZUVVdk2XxBLQySpdln568cUXyzbKelGJLKofqv/VOKrxUCjbSyUYVbaXus7KKlMWbBdLTPVDWZhdk3e6jJWyFdW9My5+5TemUSx+YxrF4jemUSx+YxrF4jemUSx+YxplUVZfRBwATgIzwKnMnFLPn5mZKW0ZZWtUVp/KOFM24IYNG8qYqrlXLTWlbCN1XspGUxZhl2xGZZWpGnhda+5V1qIajyNHjpSxZ555pozt37+/jFXHU/eHyqhU/Vd2nhrHCtXHKrYQC3ApfP4/zsznlmA/xpge8dt+YxplseJP4IcR8bOI2LMUHTLG9MNi3/ZfnZmHImIrcE9EPJ6Z981+wvCfwh7Qyz0bY/plUa/8mXlo+PsI8D3gqhHP2ZuZU5k5pdZmN8b0S2fxR8S6iFh/5jHwAeCRpeqYMWZ5Wczb/m3A94bWwrnAHZn5H6pBZpa2nbJJKrrYJ6Az5lQxy8rSU8VHV65cWcaULaPsJpXFVll9aqyUValsRZVdWJ23WqLs0KFDZezRRx/t1K76qHnhhReWbRTqmikbUMWq+1FdF2UDjkvnPWT
m08AVi+6BMWYi2OozplEsfmMaxeI3plEsfmMaxeI3plF6LeCZmaXloSyg6stBKitOWWVq3bRqXUCobS9lh6nMQ2U5dikkCrU9tBQFHxdC1X9l6arsQmVvdrExlT2rYupaq36o+7G6j1XR1Sq2kOvsV35jGsXiN6ZRLH5jGsXiN6ZRLH5jGqX32f5qaSI1m1vN9qvZUDVbrmb71Wx01e75558v26gaBiqmZpxVanS1T1W38IILLuh0LDXzXTkSanzVNduyZUsZ27ZtWxm79NJLR25X56X6eOzYsU7tutQFVPfAUuBXfmMaxeI3plEsfmMaxeI3plEsfmMaxeI3plF6t/qq5Adl9VX14LokuIBOwFD2YZWAcfLkybJNl4QOgIsuuqiMKWuuqk2n2qhafIrKtoV6jJUtqpK7lFW5devWMrZr166R29W9c/z48TKm+qiSuJStW/VF7a+yKp3YY4yZF4vfmEax+I1pFIvfmEax+I1pFIvfmEaZ1+qLiNuAPweOZOY7hts2Ad8CdgEHgI9m5gvz7Utl9SkLRVlpFV0XBe1SO08tyaWsQ7XkkrKGNm7cWMbWr18/cnvX8VBWpbLtqnp8amktVcNP1VaszlnF1HXpYrHNF1OZpJXl26XuoroX5zLOK//XgGvmbLsJuDczdwP3Dv82xryJmFf8mXkfMPdf/LXA7cPHtwMfWuJ+GWOWma6f+bdl5jTA8Hf9FStjzFnJsn+9NyL2AHug++dOY8zS0/WV/3BEbAcY/j5SPTEz92bmVGZOLXdZImPM+HQV/93AdcPH1wHfX5ruGGP6Yhyr75vAe4EtEXEQ+CzwOeCuiLge+BXwkXEOlpll0UdlsVUZYiorTmX1dS38WfVdHUvZcps2bSpjVXYewNq1a8tYlaGnzlllVE5PT5exgwcPlrEqM05lzKlxVJmHKkuzstiUddglaxJgx44dZUxlEVaFP1WbqpDoQj5azyv+zPx4EXr/2Ecxxpx1+Bt+xjSKxW9Mo1j8xjSKxW9Mo1j8xjRKrwU8I6LMZFOZVJUtozLflLWlrCG13lpl9ansPHVeav25rgU8q+MpW1TZbypz7+jRo2WsytBTGZrqvBRdrrUaD4WyHJV1q9pVfVH36YEDB0ZuX+qsPmPM7yEWvzGNYvEb0ygWvzGNYvEb0ygWvzGN0rvVV2VuKbtsIfbFGbpkCUJt56mYykZTGXhq/bmua+tVNRO6Zswp+0rZZeq8K1RGmrJ11ThW46EKk7788stlTJ2zKripallU96q6F9X9PS5+5TemUSx+YxrF4jemUSx+YxrF4jemUXqd7VeoGfguSRjKPehKNRut6sGpWWo1A6xmjrvUIFQzx2q2/5JLLiljmzdvLmMqWahCzcCrc+7iEqjls6ol5eaLvfBCvWKdWo6uumbKoakS0BbiAviV35hGsfiNaRSL35hGsfiNaRSL35hGsfiNaZRxluu6Dfhz4EhmvmO47Rbgk8CZIm43Z+YP5tvXOeecUyZ8qESWyq5RySOqdp5aBklZOZWlpOrtKatP2TJdagmqfSq7VI2HskzXr19fxqo+qmumEmrUdVG2XZeahl2TzJRdrWoXVpaeGt/KJlYW8VzGeeX/GnDNiO1fyswrhz/zCt8Yc3Yxr/gz8z6gLuFqjHlTspjP/DdGxEMRcVtE1EuXGmPOSrqK/yvAW4ErgWngC9UTI2JPROyLiH3qc5sxpl86iT8zD2fmTGaeBr4KXCWeuzczpzJzaiFrhxtjlpdO4o+I7bP+/DDwyNJ0xxjTF+NYfd8E3gtsiYiDwGeB90bElUACB4BPjXOwVatW8Za3vGVkTGWPVZaHsgdVpp2y0Z577rkyVmWdqXc06qOOWgpLZbipbK8qe0z1Q1lUyjpStldlpSnL66WXXipjKqvvyJEjZawaD5Vlp1BWpeqjilX399atWxfcRt0bc5lX/Jn58RGbbx37CMaYsxJ/w8+YRrH4jWkUi9+YRrH4jWkUi9+YRum1gOd
5553HFVdcMTK2ZcuWsl1lr6iMOWXJHDt2rIw98cQTZezJJ58cuf3EiRNlG2WxqcKZKitRxapsuq5ZccpyVHZZZfWpsVcWrMo8VJZpdd6qH2o8lIXctZBrdd7Kyq7OS2V8zsWv/MY0isVvTKNY/MY0isVvTKNY/MY0isVvTKP0avWtWbOG3bt3j4xV26EuqKgymLoWYdy/f38Zq+whZTWpPirLTvVxw4YNZayyD5VV9uKLL5YxleWo2lU2oCrSqcZKjYdaF7CyvlR2oVpzT42HsuZUrBoTZdtVsYVkK/qV35hGsfiNaRSL35hGsfiNaRSL35hG6XW2/9xzz2Xz5s0jY9u2bSvbVTXmVO05hZr5VjO9v/71r0duV7P9Xeu6qaW8VBLUhReOXkJB7U/Nlh8+fLiMqYSmKllFjf3GjRvLmLrWKlYdTyVVqVl2lSClrqdqV7kman+Vm7UQTfiV35hGsfiNaRSL35hGsfiNaRSL35hGsfiNaZRxluvaCXwduBg4DezNzC9HxCbgW8AuBkt2fTQza5/s//e34E5W9eBUfTlleSg7T9V2qywxtcyUsthUkovap0ouUTZghUpyOXr0aBlT/a8STFSdu4svvriMVctTgV4urUL1o7JL50NZc2qsKttOaaWLjuYyziv/KeAzmfk24N3ADRHxduAm4N7M3A3cO/zbGPMmYV7xZ+Z0Zj4wfHwSeAzYAVwL3D582u3Ah5ark8aYpWdBn/kjYhfwTuB+YFtmTsPgHwRQLylqjDnrGFv8EXE+8B3g05lZf6/zd9vtiYh9EbFPfdY2xvTLWOKPiJUMhP+NzPzucPPhiNg+jG8HRi6Snpl7M3MqM6e6TqQYY5aeecUfg2nFW4HHMvOLs0J3A9cNH18HfH/pu2eMWS7Gyeq7GvgE8HBEPDjcdjPwOeCuiLge+BXwkfl2lJmlBVfZeVDbRmpZJWWtqI8fKqOrqp2nlpnqWitOWX2qTltlfyobSo2j6r/K0KtQtpzK6quyQUHXx6v6qOonqneo6lhqjFV2ZGUHd8kIVW3mMq/4M/PHQGUqvn/sIxljzir8DT9jGsXiN6ZRLH5jGsXiN6ZRLH5jGqXXAp5QWy8qC6+yQlSbrhl/69atK2OXXHLJyO2rVq0q2yirTNleqiiostgqq1LZkSqmrKO1a9eWserclGWnsvp27txZxpQ1V1m+avkvtdSbOme1T1XAs7Ju1XXuYrPOxa/8xjSKxW9Mo1j8xjSKxW9Mo1j8xjSKxW9Mo/Rq9amsvi721YoVK8o2yqJSxQ9VuyoTTGWjqUKRVZYg6CKd1dpuUGdHKstRjWPX7LfqvNU5q/Xz1P2hshKrduqclfWpsi2V5avOrbJFlT2o+jEufuU3plEsfmMaxeI3plEsfmMaxeI3plF6n+2vZl9VDb8qpmrZnThRVxdXMTXLXrVTs8NqBli5BMqRUDPH1ay+mjnukqAD2smoYippRjkSBw8eLGNdEnHUdVGOj1rOTY3Vpk2byljlgKglyipNLGQZL7/yG9MoFr8xjWLxG9MoFr8xjWLxG9MoFr8xjTKv1RcRO4GvAxcDp4G9mfnliLgF+CRwdPjUmzPzB2pfp0+fLmvrKfutsoDUcldPPfVUGdu/f38Ze/bZZ8vY0aNHR25XiSXKelH1ApXdpGrFVRabGitllSk7r4tVqeonKptV1TRUY1VZbMqWU0u2VfUkQY9jVf8R4PLLLx+5XdmDS8E4Pv8p4DOZ+UBErAd+FhH3DGNfysx/XL7uGWOWi3HW6psGpoePT0bEY8CO5e6YMWZ5WdBn/ojYBbwTuH+46caIeCgibouIOrnbGHPWMbb4I+J84DvApzPzBPAV4K3AlQzeGXyhaLcnIvZFxD5VhMIY0y9jiT8iVjIQ/jcy87sAmXk4M2cy8zTwVeCqUW0zc29mTmXmlKriYozpl3nFH4Np21uBxzLzi7O2b5/1tA8Djyx994wxy8U4s/1XA58AHo6IB4fbbgY+HhF
XAgkcAD41345Onz5dWnqHDx8u21WZVNPT02Wbxx9/vIypduqjSZVFqDLmVLaispSUzaOy8KradMoOUxaVquGnbMBqrKrls0CPfdc6g1VM9UNdT9VO9VFZnNU7YrW/anzVPTWXcWb7fwyMMm2lp2+MObvxN/yMaRSL35hGsfiNaRSL35hGsfiNaZReC3ieOnWK48ePl7GKKqPr0KFDC24DulCkysKrrC1lNVVZjKCLSKovRKlYtYyT6qPKVFN2kyokWlmLyvpUFpuyI1U/qvNW95vKMFX3lUKNf3Ufqz5W46uWNfudfYz9TGPM7xUWvzGNYvEb0ygWvzGNYvEb0ygWvzGN0qvVNzMzUxaSrCwqqAtkKitEFcdU7dT6f5W9otqoTDVlsSnLUVmElR25devWso3qf5djQV3cU2UrqnXwVB+VDVgVO1Xn1fWaqftKredYnXeX4q+qf3PxK78xjWLxG9MoFr8xjWLxG9MoFr8xjWLxG9MovVp9mdmp8GBlsSmrSWWjqbXuVMZflZHW1bJTGW5q/T+V/Vad9+bNm8s2auy7FvesYqrNli1bypjKjlRjVVl66rooO1JdM3XvqGy7yv5WY19pwlafMWZeLH5jGsXiN6ZRLH5jGsXiN6ZR5p3tj4g1wH3A6uHzv52Zn42Iy4A7gU3AA8AnMrPOvjhzwGIGU83OVzOzamZTzeirmm9q5riacVZJGwuZfZ2NShJRs8rV7LZarksluailwdQ4VuOv6g+uXr26jKnZeZWIUyWMqWumksIuuuiiMqbGSjkqlROgHI7K8Vnq2f5Xgfdl5hUMluO+JiLeDXwe+FJm7gZeAK4f+6jGmIkzr/hzwJl/nyuHPwm8D/j2cPvtwIeWpYfGmGVhrM/8EbFiuELvEeAe4JfA8cw88970ILBjebpojFkOxhJ/Zs5k5pXApcBVwNtGPW1U24jYExH7ImKfWt7YGNMvC5rtz8zjwH8B7wY2RsSZ2btLgZErD2Tm3sycyswpNZFijOmXecUfERdFxMbh4/OAPwEeA34E/MXwadcB31+uThpjlp5xEnu2A7dHxAoG/yzuysx/j4hfAHdGxN8D/w3cOs4BleVRUVkhyg5TiRSqD8pyrGxK9Y5GJbIolO2lbMxqTFQbdSxlAyqqBBiVlNTVnlUJXl2sVnUPqGutkn662MFd7gFlic5lXvFn5kPAO0dsf5rB539jzJsQf8PPmEax+I1pFIvfmEax+I1pFIvfmEaJLtZb54NFHAX+d/jnFuC53g5e4368EffjjbzZ+vEHmVmnHs6iV/G/4cAR+zJzaiIHdz/cD/fDb/uNaRWL35hGmaT4907w2LNxP96I+/FGfm/7MbHP/MaYyeK3/cY0ykTEHxHXRMT/RMRTEXHTJPow7MeBiHg4Ih6MiH09Hve2iDgSEY/M2rYpIu6JiCeHvy+cUD9uiYhnh2PyYER8sId+7IyIH0XEYxHxaET81XB7r2Mi+tHrmETEmoj4SUT8fNiPvxtuvywi7h+Ox7ciolvK5Rkys9cfYAWDMmCXA6uAnwNv77sfw74cALZM4LjvAd4FPDJr2z8ANw0f3wR8fkL9uAX4657HYzvwruHj9cATwNv7HhPRj17HBAjg/OHjlcD9DAro3AV8bLj9n4G/XMxxJvHKfxXwVGY+nYNS33cC106gHxMjM+8Dnp+z+VoGhVChp4KoRT96JzOnM/OB4eOTDIrF7KDnMRH96JUcsOxFcych/h3AM7P+nmTxzwR+GBE/i4g9E+rDGbZl5jQMbkJg6wT7cmNEPDT8WLDsHz9mExG7GNSPuJ8JjsmcfkDPY9JH0dxJiH9UqZFJWQ5XZ+a7gD8DboiI90yoH2cTXwHeymCNhmngC30dOCLOB74DfDozT/R13DH60fuY5CKK5o7LJMR/ENg56++y+Odyk5mHhr+PAN9jspWJDkfEdoDh7yOT6ERmHh7eeKeBr9LTmETESgaC+0Zmfne4ufcxGdWPSY3J8NgLLpo
7LpMQ/0+B3cOZy1XAx4C7++5ERKyLiPVnHgMfAB7RrZaVuxkUQoUJFkQ9I7YhH6aHMYlB4blbgccy84uzQr2OSdWPvsekt6K5fc1gzpnN/CCDmdRfAn8zoT5czsBp+DnwaJ/9AL7J4O3j6wzeCV0PbAbuBZ4c/t40oX78K/Aw8BAD8W3voR9/xOAt7EPAg8OfD/Y9JqIfvY4J8IcMiuI+xOAfzd/Oumd/AjwF/BuwejHH8Tf8jGkUf8PPmEax+I1pFIvfmEax+I1pFIvfmEax+I1pFIvfmEax+I1plP8DMQi2q65RHfEAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "output = conv(img.unsqueeze(0))\n", "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)\n", "\n", "with torch.no_grad():\n", " conv.weight[:] = torch.tensor([[-1.0, 0.0, 1.0],\n", " [-1.0, 0.0, 1.0],\n", " [-1.0, 0.0, 1.0]])\n", " conv.bias.zero_()" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGLdJREFUeJztnV2snNV1hp+FweAfgjnYxsfYgHFIBImKYx2hSFQoTdqIRpFIpCZKLiIuUBxVQSpSeoGo1IDUi6RqEuWiSuUUFFKlIWl+FFShNgilQrkhMZS/YFoTY4zrf8DY4d/26sWMpcNh1nvGc3y+Mez3kY7OnL1mf9+ePd86M7PfedeOzMQY0x5njHsAxpjx4OQ3plGc/MY0ipPfmEZx8hvTKE5+YxrFyW9Mozj5jWkUJ78xjXLmXDpHxHXAt4EFwD9n5tfU/ZcsWZITExMDY2+88UbZb/HixQPbzzij/t91/PjxMhYRZUwds+Lo0aNlTD0uNY4zzxztqam+sanOpWKjfgO0mkd1rmPHjo0UU891NQ41v+pc8/GN2GpO1LnOOuusge0HDx7kyJEj9SRPY+Tkj4gFwD8CfwbsAn4bEfdk5pNVn4mJCW6++eaBseeee64818aNGwe2V/8UAF555ZUytnDhwjJ2zjnnlLHqQnr++efLPs8++2wZUxfg8uXLy5ii+kekHnN1IQG8+eabZUxdnIsWLRrYrv65vvzyy2Xs8OHDI8WWLl06sL16EQL4wx/+UMbUfIz6j+3ss88e2K5eVFasWDGw/fbbby/7zGQub/uvBp7OzO2Z+QZwN3D9HI5njOmQuST/RcD0l+td/TZjzDuAuST/oPc4b3sfGBGbImJLRGxRb+uMMd0yl+TfBayd9vcaYPfMO2Xm5sycysypJUuWzOF0xphTyVyS/7fA5RGxLiIWAp8D7jk1wzLGzDcjr/Zn5tGIuAn4T3pS352Z+TvVJyLKlWW1slmtiquV+QMHDpQx9fFj7dq1ZWzZsmUD219//fWyj2JUSUnNVbXirOZKrfarmHrc1WMbVZ5V6o1SilavXj2w/bLLLiv7vPbaa2VMKQsLFiwoY+pxV/Oo5rdSMU5Gqp6Tzp+Z9wL3zuUYxpjx4G/4GdMoTn5jGsXJb0yjOPmNaRQnvzGNMqfV/pPljDPOKM04SmI7cuTIwHZlslBuupdeeqmMKalk5cqVA9svvvjiss8LL7xQxpTp58UXXyxj73nPe8rYueeeO7C9koZAG4yU7KXGWEmO55133kjjUOzZs6eMVfLsqlWryj5Kztu/f38ZU1KfQl2rFaM4Xd9235M+qzHmXYGT35hGcfIb0yhOfmMaxclvTKN0utqv+MAHPlDG9u7dO7B
dmXdU2SplIKnOBXV5J2WaUaiVb3XMquwT1Kv6yk6t5uPQoUMjxSrVQRmFlLKgVtJHUX3UONS5Rq0lqKgedzWHUD/PJ6M4+JXfmEZx8hvTKE5+YxrFyW9Mozj5jWkUJ78xjdKp1Hfs2LFSHnrve99b9lM7oVRUJhzQBpKnn366jO3atWtg+yWXXFL2UTKUkpvOP//8MjbKlmKqJqCSr1S9Q1VjrtpxSElRyiikzrVmzZoyVu1so8xdKqaeM/W8KPNOdcz5rnbtV35jGsXJb0yjOPmNaRQnvzGN4uQ3plGc/MY0ypykvojYARwBjgFHM3NK3f/o0aOl1KeknEoKUe42JQ9OTk6WMUU1RuUEVPKVGn9Vow301lWV81DJm0p+U/2UHFm5KpUT86mnnipjaozr168vY1WtvmqeQEt2ymmnXH1qi7XqOlDO1EqeVZLuTE6Fzv8nmXnwFBzHGNMhfttvTKPMNfkT+GVEPBQRm07FgIwx3TDXt/3XZObuiFgJ3BcRT2XmA9Pv0P+nsAl0vXljTLfM6ZU/M3f3f+8Hfg5cPeA+mzNzKjOn1CKWMaZbRk7+iFgSEeeeuA18HHjiVA3MGDO/zOVt/4XAz/uS2pnAv2bmf6gOx48fL2Uq5X6r5CYl51VbfAFMTEyUMeUGrOQh5fRSH3UOHqxFEhVTElAl9SgZSo1fOcuUJFY9N2q7q23btpWxyp0HcO2115axav6Vy27RokVlTElpyg2o5r+aYyWzVvnSidSXmduBq0btb4wZL5b6jGkUJ78xjeLkN6ZRnPzGNIqT35hG6XyvvkryUFJI5ehSfZR7TO2Dd+GFF5ax6nxKXlFSn5KGlOtMSZWVTKVko/nYf67qpx6Xil1xxRVlTM1x9djUNaBk4sOHD5cxVexUybPVczbqfpPD4ld+YxrFyW9Mozj5jWkUJ78xjeLkN6ZROl3tz8yylplavaxWPZcuXVr2eeGFF8qYqqunVoFfe+21ge1qJV2NUW3zpWq+KUPTKPXgRlUd1HNWjaOaQ9BzPzVVl4es6vQBbN26dWC7qglYbcsG2pik5lgZpJSRqGIU5WkmfuU3plGc/MY0ipPfmEZx8hvTKE5+YxrFyW9Mo3Qu9VW1x5SkVMlGauskJa2o+m0qVkl6SipTMuCaNWvKmKqrp0xLldSjtkN79dVXy5jaGkzNVWU+UhKmMlVt3LixjCmJsDLiqK3SlLFHzYd6ztT5qlqIqq5l9Xxa6jPGzIqT35hGcfIb0yhOfmMaxclvTKM4+Y1plFmlvoi4E/gksD8zP9hvmwB+BFwK7AA+m5m1Va5PZpauLrX1UyUPqT7nn39+GVPSnJKvqlpxShpS0ouSqFRMORYr2U7JRkrqUzXrlOxVjV9JXuvWrStjF1xwQRlT46+cdqM6MZVUqaRn5SKs5lE5IKt6gep5nskwr/zfA66b0XYLcH9mXg7c3//bGPMOYtbkz8wHgJkvNdcDd/Vv3wV86hSPyxgzz4z6mf/CzNwD0P9db21rjDktmfcFv4jYFBFbImKL+gxjjOmWUZN/X0RMAvR/l7WNMnNzZk5l5pRaxDLGdMuoyX8PcEP/9g3AL07NcIwxXTGM1PdD4CPA8ojYBXwV+Brw44i4EdgJfGaYkylXnypKWW25pLaZUu8ylINQyYBVocXly5eXfUZxc4GW2EbZqkk9rt27d5cxJSsqF1v1fK5YsaLso7YhUxLsoUOHylgl26lim8pduGzZsjKm5Dz1kbe6HpUc+cwzzwxsV+7Ntx1/tjtk5ueL0MeGPosx5rTD3/AzplGc/MY0ipPfmEZx8hvTKE5+Yxql0wKeEVFKLEoSq+RBJYcplESoHGKVq0+50ZQMqMavJDYl9VWSmJL6Dh48WMbUfKxcWX+ru5L6lHy1evXqMqZkNPXYKtQ1oMZ43nnnlTHlBnzuuefKWCVVKmdqtWegch3OxK/8xjSKk9+YRnHyG9MoTn5jGsXJb0yjOPmNaZTOpb5KRhnFqaa
KFSopRzm61DGr2OLFi8s+ygWm3GhKYlNzNco8KgfkqIUuq6KUO3fuLPtceeWVZUzJio8++mgZq/Y1VI9ZyWWqkKgq4KnOV12rqvirkj6Hxa/8xjSKk9+YRnHyG9MoTn5jGsXJb0yjdLrar1A1zqqVTWUGUltJqVVUtbpdGXvUOFS9QFVvTdX3U4pEdUz1uFTNOjUO9bj37t07sH3btm1ln6uvvrqMVXM/2ziqlXulpiilSF2nyiCljlmZhVSNxEo9UM/X2+479D2NMe8qnPzGNIqT35hGcfIb0yhOfmMaxclvTKMMs13XncAngf2Z+cF+223AF4ETrolbM/PeIY5VbnmlzBSV5KEMNcrg8vLLL5cxNY5KLlMyzp49e8qYMvaMKvVVEpaSw5RBR0mEaou1ShJTEqyaDyWZqsdWXTvKGKOuKyUTq+da9Vu1atXAdnUNV/N7MsauYV75vwdcN6D9W5m5of8za+IbY04vZk3+zHwAqEvJGmPekczlM/9NEfFYRNwZEXWNYWPMacmoyf8dYD2wAdgDfKO6Y0RsiogtEbFFfaXSGNMtIyV/Zu7LzGOZeRz4LlB+KTszN2fmVGZOVYt9xpjuGSn5I2Jy2p+fBp44NcMxxnTFMFLfD4GPAMsjYhfwVeAjEbEBSGAH8KVhTrZw4UIuvvjigTElbVUxVYtPOeZefPHFkWIrVqwY2P7888+XfdS2W9XWWqClKOUsq2RAVUNOyUNKVlyyZEkZq+aqagfttFMyoHpHWY1fXR9q2zB1LnVMJUdW51MSciVXn4zUN2vyZ+bnBzTfMfQZjDGnJf6GnzGN4uQ3plGc/MY0ipPfmEZx8hvTKJ0W8Fy8eDEbNmwYGNu9e3fZr5JQlKyhHGfVFk4A27dvL2OV+03JP0rOU04vJQ0p11kll6lxKIlK9VPzX83VmjVryj5qHpXspWTRI0eODGxXLkG1JZcqdjoxMVHG1FxVUraSZ6utwU5mGy+/8hvTKE5+YxrFyW9Mozj5jWkUJ78xjeLkN6ZROpX6zjnnHC6//PKBsZ07d5b9KilKOQGV5KGce08++WQZq2Seq666quyjUHJTJeUAnH9+XTipkpT2799f9lEuwVEKmkK9t97y5cvLPtWedaClMiUDVrKocgmquVKuRLWfoCpcWjk/lXRYuWOV03UmfuU3plGc/MY0ipPfmEZx8hvTKE5+Yxql09X+BQsWlCu6o2yhVW3FBNokos6lDEZVPbhly5aVfZSJSNX+UwYYpQRUxh61Iq7q9CmDlFqBr54ztYI9OTlZxpTBSNX+q86n+qjHrGoyjjpXlfqkru9K8bGxxxgzK05+YxrFyW9Mozj5jWkUJ78xjeLkN6ZRhtmuay3wfWAVcBzYnJnfjogJ4EfApfS27PpsZtaOGXo166q6dUrKqWQqVQNPmU5GNYlUdelUvb1RpT4lvylDUyUPqflQEpXqp6SoKqaMPevXry9jSmJTc1wZe5RBR9VIVEYnZdRS/So5Ul2nVb6onJjJMK/8R4GvZOYVwIeBL0fElcAtwP2ZeTlwf/9vY8w7hFmTPzP3ZObD/dtHgK3ARcD1wF39u90FfGq+BmmMOfWc1Gf+iLgU+BDwIHBhZu6B3j8IYOWpHpwxZv4YOvkjYinwU+DmzDx8Ev02RcSWiNiiPuMaY7plqOSPiLPoJf4PMvNn/eZ9ETHZj08CA8ufZObmzJzKzCm1GYIxpltmTf7oLQXfAWzNzG9OC90D3NC/fQPwi1M/PGPMfDGMq+8a4AvA4xHxSL/tVuBrwI8j4kZgJ/CZYU6oJI+TRdVFU+dRUs773ve+MlZJfWqLL/VRR0llynVWbUGlUFs/qeOpunqqlmD13Kg+q1atKmPPPPNMGXvppZfKWPW41TWwbt26MrZ3794ytmvXrjKm6iRW8qeSvw8fHvzJW0nVM5k1+TPz10AlBH9s6DMZY04r/A0/YxrFyW9Mozj5jWkUJ78xjeLkN6ZROi3gmZmlFKFkr8r
dpOQw5W5S2ypdcsklZawq1KkKgqqCipV0CLBv374ypqS5Si6rtjwDLUMpp51ynVWS2JIlS8o+qjCpGqOa/1GuN3XtqPErObWS5qB2cKrH1ZWrzxjzLsTJb0yjOPmNaRQnvzGN4uQ3plGc/MY0ymkj9SlJrJJllFyj3E3K8af2+KukqIULF5Z9lKy4cmVd/Eg5xFQxy4suumhge7UfHOhioUrqU067SppTBUGfffbZMqYes5LfqutqPvbjU8VJ1fVdFWTdv39giQygvoYt9RljZsXJb0yjOPmNaRQnvzGN4uQ3plE6X+1XK8sV1QqrOpaq76e2u1KrspU5Zu3atSONozIKga79p2KVGUSZgZShRtXwU+Oo5lGZgR566KEyprbCUrX/qvlQprBDhw6VMWXeUaqPqg1Zbff2yCOPDGyHevxqfmfiV35jGsXJb0yjOPmNaRQnvzGN4uQ3plGc/MY0yqxSX0SsBb4PrAKOA5sz89sRcRvwReBA/663Zua96ljHjx8vpRdltqnqlakaZ2oLJ1U7T8k1VR02VV9OGVmUSeT9739/GduxY0cZqww8an4nJyfLmDKKKONJ1U8ZjJRRSBm1RjH2qOtD1TtUMrGa48pwBbWcqqS+U1HDbxid/yjwlcx8OCLOBR6KiPv6sW9l5j8MfTZjzGnDMHv17QH29G8fiYitQP1vzBjzjuCkPvNHxKXAh4AH+003RcRjEXFnRNTbrxpjTjuGTv6IWAr8FLg5Mw8D3wHWAxvovTP4RtFvU0RsiYgt6vOeMaZbhkr+iDiLXuL/IDN/BpCZ+zLzWGYeB74LXD2ob2ZuzsypzJxSe7MbY7pl1uSP3pL0HcDWzPzmtPbpS8SfBp449cMzxswXw6z2XwN8AXg8Ik5oD7cCn4+IDUACO4AvzXagY8eOlY4pJZdV0ouS+pTsomrnLV68uIxVW2iprZiUC0yhnILKTbd9+/aB7UqiGsUVB9oBWVGND7TTbsWKFWVMyanVdTDqtaPqRlbuPNDyclXnUR2vesxKPn7bMWa7Q2b+Ghh0RKnpG2NOb/wNP2MaxclvTKM4+Y1pFCe/MY3i5DemUTot4KmkPlV4sHIqKalJbZOl5DwlyVTyipKGVOFM5fRSxT2VfFg57dRcqeMpCVbNVdWvkktBPy/q+VRjrJ4bJYnt3r27jKl+V1xxRRlTY6y2B1PXwOrVqwe2KylyJn7lN6ZRnPzGNIqT35hGcfIb0yhOfmMaxclvTKN0LvVVDjjlzKqKalZyB2j5TRWDVAUQqz3tDh48WPZRDjwl9SnpU7m9JiYmBraropSqmGUlQ6lzQS05jVJ8FLTspeSt6rpSsqIahyoWquRItddgJX+q+hfVdWqpzxgzK05+YxrFyW9Mozj5jWkUJ78xjeLkN6ZROpX6MrOU2ZSzrJIvlKyhimpW+5zN1q/aU00Vx3zllVdGio0q9V1wwQUD29UY1bkOHDhQxpScWjn+VGFSJZmqmHLMVRKhciQqGVDJxErWPZnCmieonkuFksxn4ld+YxrFyW9Mozj5jWkUJ78xjeLkN6ZRZl0ajIhzgAeAs/v3/0lmfjUi1gF3AxPAw8AXMrNeRqe34lytVKvV0GprpVFr8akVfWXAqKiMR6DNL0p1qGodgl5xrhQQtd2Vmiu1rZUyQVXPpzLoKBVj7969ZUwpAdX4K+UGYM2aNWVMMarCVK3QK2PPokWLBrYrVedt9x3iPq8DH83Mq+htx31dRHwY+Drwrcy8HHgRuHHosxpjxs6syZ89TrwcntX/SeCjwE/67XcBn5qXERpj5oWh3iNExIL+Dr37gfuA3wOHMvPEtzx2AbU53Rhz2jFU8mfmsczcAKwBrgYGFSgf+AEwIjZFxJaI2KI+ExljuuWkVvsz8xDwX8CHgWURcWKlYg0wcKeDzNycmVOZOaUqnRhjumXW5I+IFRGxrH97EfCnwFbgV8Bf9O92A/CL+RqkMebUM4wLYBK4KyIW0Ptn8ePM/Pe
IeBK4OyL+Dvhv4I5hTqjkoYpK2lKSlzI4KEOQOma15ZWSypSEqWQZVTtPbQFWmW1GqXMH2jSjnsvKLFRJVKDlN/WY1Rir2oXKlLRy5cqTPh7Aq6++WsZGqRs5iuysTFozmTX5M/Mx4EMD2rfT+/xvjHkH4m/4GdMoTn5jGsXJb0yjOPmNaRQnvzGNEqNIbyOfLOIA8Gz/z+VAbcfqDo/jrXgcb+WdNo5LMrO2cE6j0+R/y4kjtmTm1FhO7nF4HB6H3/Yb0ypOfmMaZZzJv3mM556Ox/FWPI638q4dx9g+8xtjxovf9hvTKGNJ/oi4LiL+JyKejohbxjGG/jh2RMTjEfFIRGzp8Lx3RsT+iHhiWttERNwXEdv6v+vqjfM7jtsi4v/6c/JIRHyig3GsjYhfRcTWiPhdRPxVv73TORHj6HROIuKciPhNRDzaH8ft/fZ1EfFgfz5+FBF15dhhyMxOf4AF9MqAXQYsBB4Frux6HP2x7ACWj+G81wIbgSemtf09cEv/9i3A18c0jtuAv+54PiaBjf3b5wL/C1zZ9ZyIcXQ6J0AAS/u3zwIepFdA58fA5/rt/wT85VzOM45X/quBpzNze/ZKfd8NXD+GcYyNzHwAmGnYv55eIVToqCBqMY7Oycw9mflw//YResViLqLjORHj6JTsMe9Fc8eR/BcBz037e5zFPxP4ZUQ8FBGbxjSGE1yYmXugdxECdUWJ+eemiHis/7Fg3j9+TCciLqVXP+JBxjgnM8YBHc9JF0Vzx5H8g0rbjEtyuCYzNwJ/Dnw5Iq4d0zhOJ74DrKe3R8Me4BtdnTgilgI/BW7OzLFVex0wjs7nJOdQNHdYxpH8u4Dpm7SXxT/nm8zc3f+9H/g5461MtC8iJgH6v/ePYxCZua9/4R0HvktHcxIRZ9FLuB9k5s/6zZ3PyaBxjGtO+uc+6aK5wzKO5P8tcHl/5XIh8Dngnq4HERFLIuLcE7eBjwNP6F7zyj30CqHCGAuinki2Pp+mgzmJXqHDO4CtmfnNaaFO56QaR9dz0lnR3K5WMGesZn6C3krq74G/GdMYLqOnNDwK/K7LcQA/pPf28U1674RuBC4A7ge29X9PjGkc/wI8DjxGL/kmOxjHH9N7C/sY8Ej/5xNdz4kYR6dzAvwRvaK4j9H7R/O3067Z3wBPA/8GnD2X8/gbfsY0ir/hZ0yjOPmNaRQnvzGN4uQ3plGc/MY0ipPfmEZx8hvTKE5+Yxrl/wHo7ZuXD1PlgwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "output = conv(img.unsqueeze(0))\n", "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "pool = nn.MaxPool2d(2)" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 3, 16, 16])" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = pool(img.unsqueeze(0))\n", "\n", "output.shape" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "ellipsis is not a Module subclass", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTanh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mMaxPool2d\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m ...)\n\u001b[0m", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, *args)\u001b[0m\n\u001b[0;32m 51\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 52\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m 
\u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 53\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 54\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 55\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_get_item_by_idx\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0miterator\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0midx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36madd_module\u001b[1;34m(self, name, module)\u001b[0m\n\u001b[0;32m 171\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mModule\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 172\u001b[0m raise TypeError(\"{} is not a Module subclass\".format(\n\u001b[1;32m--> 173\u001b[1;33m torch.typename(module)))\n\u001b[0m\u001b[0;32m 174\u001b[0m \u001b[1;32melif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_six\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstring_classes\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 175\u001b[0m raise TypeError(\"module name should be a string. 
Got {}\".format(\n", "\u001b[1;31mTypeError\u001b[0m: ellipsis is not a Module subclass" ] } ], "source": [
 "# nn.Sequential only accepts nn.Module instances, so the literal `...`\n",
 "# placeholder below is rejected with a TypeError (see the stored traceback).\n",
 "# That failure is what this cell demonstrates, so it is caught and printed\n",
 "# instead of stopping Restart & Run All at this point.\n",
 "try:\n",
 "    model = nn.Sequential(\n",
 "        nn.Conv2d(3, 16, kernel_size=3, padding=1),\n",
 "        nn.Tanh(),\n",
 "        nn.MaxPool2d(2),\n",
 "        nn.Conv2d(16, 8, kernel_size=3, padding=1),\n",
 "        nn.Tanh(),\n",
 "        nn.MaxPool2d(2),\n",
 "        ...)\n",
 "except TypeError as err:\n",
 "    print(\"TypeError:\", err)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "model = nn.Sequential(\n",
 "    nn.Conv2d(3, 16, kernel_size=3, padding=1),\n",
 "    nn.Tanh(),\n",
 "    nn.MaxPool2d(2),\n",
 "    nn.Conv2d(16, 8, kernel_size=3, padding=1),\n",
 "    nn.Tanh(),\n",
 "    nn.MaxPool2d(2),\n",
 "    # WARNING: something missing here\n",
 "    nn.Linear(512, 32),\n",
 "    nn.Tanh(),\n",
 "    nn.Linear(32, 2))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "sum([p.numel() for p in model.parameters()])"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# The model above has no step that flattens the (1, 8, 8, 8) output of the\n",
 "# second MaxPool2d into a (1, 512) vector, so nn.Linear(512, 32) receives\n",
 "# inputs whose last dimension is 8 and raises a RuntimeError. The missing\n",
 "# flatten is what this cell demonstrates; Net below adds it with out.view().\n",
 "try:\n",
 "    model(img.unsqueeze(0))\n",
 "except RuntimeError as err:\n",
 "    print(\"RuntimeError:\", err)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "class Net(nn.Module):\n",
 "    def __init__(self):\n",
 "        super(Net, self).__init__()\n",
 "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
 "        self.act1 = nn.Tanh()\n",
 "        self.pool1 = nn.MaxPool2d(2)\n",
 "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
 "        self.act2 = nn.Tanh()\n",
 "        self.pool2 = nn.MaxPool2d(2)\n",
 "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
 "        self.act3 = nn.Tanh()\n",
 "        self.fc2 = nn.Linear(32, 2)\n",
 "\n",
 "    def forward(self, x):\n",
 "        out = self.pool1(self.act1(self.conv1(x)))\n",
 "        out = self.pool2(self.act2(self.conv2(out)))\n",
 "        # flatten the 8 channels x 8 x 8 feature maps into 512-wide rows for fc1\n",
 "        out = out.view(-1, 8 * 8 * 8)\n",
 "        out = self.act3(self.fc1(out))\n",
 "        out = self.fc2(out)\n",
 "        return out"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "model = Net()\n",
 "\n",
 "sum([p.numel() for p in model.parameters()])"
] }, { "cell_type": "code", "execution_count":
null, "metadata": {}, "outputs": [], "source": [
 "import torch.nn.functional as F\n",
 "\n",
 "class Net(nn.Module):\n",
 "    def __init__(self):\n",
 "        super(Net, self).__init__()\n",
 "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
 "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
 "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
 "        self.fc2 = nn.Linear(32, 2)\n",
 "\n",
 "    def forward(self, x):\n",
 "        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)\n",
 "        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)\n",
 "        out = out.view(-1, 8 * 8 * 8)\n",
 "        out = torch.tanh(self.fc1(out))\n",
 "        out = self.fc2(out)\n",
 "        return out"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "model = Net()\n",
 "model(img.unsqueeze(0))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import torch\n",
 "import torch.nn as nn\n",
 "import torch.nn.functional as F\n",
 "import torch.optim as optim  # needed for optim.SGD below; keeps this cell self-contained\n",
 "\n",
 "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
 "                                           shuffle=True)\n",
 "\n",
 "class Net(nn.Module):\n",
 "    def __init__(self):\n",
 "        super(Net, self).__init__()\n",
 "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
 "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
 "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
 "        self.fc2 = nn.Linear(32, 2)\n",
 "\n",
 "    def forward(self, x):\n",
 "        out = F.max_pool2d(torch.relu(self.conv1(x)), 2)\n",
 "        out = F.max_pool2d(torch.relu(self.conv2(out)), 2)\n",
 "        out = out.view(-1, 8 * 8 * 8)\n",
 "        out = torch.tanh(self.fc1(out))\n",
 "        out = self.fc2(out)\n",
 "        return out\n",
 "\n",
 "model = Net()\n",
 "\n",
 "learning_rate = 1e-2\n",
 "\n",
 "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
 "\n",
 "loss_fn = nn.CrossEntropyLoss()\n",
 "\n",
 "n_epochs = 100\n",
 "\n",
 "for epoch in range(n_epochs):\n",
 "    for imgs, labels in train_loader:\n",
 "        outputs = model(imgs)\n",
 "        loss = loss_fn(outputs, labels)\n",
 "\n",
 "        optimizer.zero_grad()\n",
 "        loss.backward()\n",
 "        optimizer.step()\n",
 "\n",
 "    # `loss` is the last minibatch's loss for this epoch, not an epoch average\n",
 "    print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Accuracy on the training set (cifar2); order does not matter for evaluation.\n",
 "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
 "                                           shuffle=False)\n",
 "\n",
 "correct = 0\n",
 "total = 0\n",
 "\n",
 "with torch.no_grad():\n",
 "    for imgs, labels in train_loader:\n",
 "        outputs = model(imgs)\n",
 "        _, predicted = torch.max(outputs, dim=1)\n",
 "        total += labels.shape[0]\n",
 "        correct += int((predicted == labels).sum())\n",
 "\n",
 "print(\"Accuracy: %f\" % (correct / total))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Accuracy on the validation set (cifar2_val).\n",
 "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n",
 "                                         shuffle=False)\n",
 "\n",
 "correct = 0\n",
 "total = 0\n",
 "\n",
 "with torch.no_grad():\n",
 "    for imgs, labels in val_loader:\n",
 "        outputs = model(imgs)\n",
 "        _, predicted = torch.max(outputs, dim=1)\n",
 "        total += labels.shape[0]\n",
 "        correct += int((predicted == labels).sum())\n",
 "\n",
 "print(\"Accuracy: %f\" % (correct / total))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import torch\n",
 "import torch.nn as nn\n",
 "import torch.nn.functional as F\n",
 "\n",
 "class Net(nn.Module):\n",
 "    def __init__(self):\n",
 "        super(Net, self).__init__()\n",
 "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
 "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
 "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
 "        self.fc2 = nn.Linear(32, 2)\n",
 "\n",
 "    def forward(self, x):\n",
 "        out = F.max_pool2d(torch.relu(self.conv1(x)), 2)\n",
 "        out = F.max_pool2d(torch.relu(self.conv2(out)), 2)\n",
 "        out = out.view(-1, 8 * 8 * 8)\n",
 "        out = torch.tanh(self.fc1(out))\n",
 "        out = self.fc2(out)\n",
 "        return out\n",
 "\n",
 "model = Net()\n",
 "sum([p.numel() for p in model.parameters()])"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# Same layers as the Net classes above, but nn.Sequential has no step that\n",
 "# flattens the (1, 8, 8, 8) output of the second MaxPool2d into the 512-wide\n",
 "# rows nn.Linear(8*8*8, 32) expects (Net does this with out.view), so the\n",
 "# call below raises a RuntimeError. It is caught so the cell still completes.\n",
 "model = nn.Sequential(\n",
 "    nn.Conv2d(3, 16, kernel_size=3, padding=1),\n",
 "    nn.Tanh(),\n",
 "    nn.MaxPool2d(2),\n",
 "    nn.Conv2d(16, 8, kernel_size=3, padding=1),\n",
 "    nn.Tanh(),\n",
 "    nn.MaxPool2d(2),\n",
 "    nn.Linear(8*8*8, 32),\n",
 "    nn.Tanh(),\n",
 "    nn.Linear(32, 2))\n",
 "\n",
 "try:\n",
 "    model(img.unsqueeze(0))\n",
 "except RuntimeError as err:\n",
 "    print(\"RuntimeError:\", err)"
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 2 }