{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%matplotlib inline\n", "from matplotlib import pyplot as plt\n", "import numpy as np\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torchvision import datasets, transforms\n", "\n", "torch.set_printoptions(edgeitems=2)\n", "# Seed torch's global RNG so later random operations (e.g. layer weight initialisation) are repeatable.\n", "torch.manual_seed(123)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class_names = ['airplane','automobile','bird','cat','deer',\n", " 'dog','frog','horse','ship','truck']" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n" ] } ], "source": [ "# Dataset directory plus the mean/std tuples handed to transforms.Normalize;\n", "# the transform is built once so the train=True and train=False splits share it.\n", "data_path = '../data-unversioned/p1ch6/'\n", "NORM_MEAN = (0.4915, 0.4823, 0.4468)\n", "NORM_STD = (0.2470, 0.2435, 0.2616)\n", "cifar10_transform = transforms.Compose([\n", "    transforms.ToTensor(),\n", "    transforms.Normalize(NORM_MEAN, NORM_STD),\n", "])\n", "cifar10 = datasets.CIFAR10(data_path, train=True, download=True,\n", "                           transform=cifar10_transform)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n" ] } ], "source": [ "# train=False split, preprocessed with the same cifar10_transform as above.\n", "cifar10_val = datasets.CIFAR10(data_path, train=False, download=True,\n", "                               transform=cifar10_transform)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# label_map: original CIFAR-10 label -> new label; samples with any other label are dropped.\n", "label_map = {0: 0, 2: 1}\n", "class_names = ['airplane', 'bird']\n", "# Filter on label_map itself so the kept labels cannot drift from the mapping's keys.\n", "cifar2 = [(img, label_map[label]) for img, label in cifar10 if label in label_map]\n", "cifar2_val = [(img, label_map[label]) for img, label in cifar10_val if label in label_map]" ] }, { "cell_type": 
"code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Fully connected classifier on flattened images: 3072 (= 3 * 32 * 32) inputs -> 2 outputs.\n", "connected_model = nn.Sequential(\n", "    nn.Linear(3072, 1024),\n", "    nn.Tanh(),\n", "    nn.Linear(1024, 512),\n", "    nn.Tanh(),\n", "    nn.Linear(512, 128),\n", "    nn.Tanh(),\n", "    nn.Linear(128, 2))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(3737474, [3145728, 1024, 524288, 512, 65536, 128, 256, 2])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def trainable_numel(model):\n", "    \"\"\"Return the element count of each parameter of `model` that requires grad.\"\"\"\n", "    return [p.numel() for p in model.parameters() if p.requires_grad]\n", "\n", "\n", "numel_list = trainable_numel(connected_model)\n", "sum(numel_list), numel_list" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Smaller variant with a single hidden layer, ending in LogSoftmax over dim=1.\n", "first_model = nn.Sequential(\n", "    nn.Linear(3072, 512),\n", "    nn.Tanh(),\n", "    nn.Linear(512, 2),\n", "    nn.LogSoftmax(dim=1))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1574402, [1572864, 512, 1024, 2])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "numel_list = trainable_numel(first_model)\n", "sum(numel_list), numel_list" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1024, 3072]), torch.Size([1024]))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A standalone layer with the same sizes as connected_model's first layer, to inspect its parameter shapes.\n", "linear = nn.Linear(3072, 1024)\n", "\n", "linear.weight.shape, linear.bias.shape" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "conv = nn.Conv2d(3, 16, kernel_size=3)  # 3 input channels, 16 output channels, 3x3 kernel\n", "conv" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([16, 3, 3, 3]), 
torch.Size([16]))" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "conv.weight.shape, conv.bias.shape" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 3, 32, 32]), torch.Size([1, 16, 30, 30]))" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img, _ = cifar2[0]\n", "output = conv(img.unsqueeze(0))\n", "img.unsqueeze(0).shape, output.shape" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGXFJREFUeJztnXt4VeWVxt8lAUEDRQhICtiAYos3Lg14G6yCqPg4g1i1OtWh1hHt6Iyd6Uyljq20tR3tU7XaUgWrI/YRFeu99YYpLVQZJCImYqxcREECJGKEiBhD1vxxDhriXisn+5yzT9Lv/T1PniTrzbf3OvucN/ucvfa3PlFVEELCY59CJ0AIKQw0PyGBQvMTEig0PyGBQvMTEig0PyGBQvMTEig0PyGBQvMTEihF2QwWkdMA3AKgG4DfqOr13t/3LhHtXxatvfWGM3Df6PA+Pe0h3aWbqfXaz37YBxSXmFofDIyMFzn/Qxux3dQ27lhjasW97TsvP2cqQHcjvssZYxxeAP7Zwbs3dLcR398Zkw8ajHiTM+ZD8ygC3qPevqPZ1Jo+dDa509Es3jfiuwFtUclkExL39l4R6QbgDQCTAWwEsBzA+ar6mjWmrFz0+5XR2j+f4uxsWHS4z0jbxAOKbIuMPqq/qU2bcLGpTZbLI+MDnZf0C3jO1L5bcYapHTvpI1OzRwEDjPhqZ4xxeAEAxY7m/UNpNOLjnTFxaXG0J4z4emdMDQabWjNsgz9XscXU3qpxdviyo1k8acTrAf04M/Nn87Z/PIA1qrpOVZsA3A9gahbbI4QkSDbmHwxgQ6vfN6ZjhJAuQDbmj3pr8ZnPECIyQ0QqRaRyR10WeyOE5JRszL8RwNBWvw8BsKntH6nqXFUtV9Xy3tYHUkJI4mRj/uUARojIMBHpAeA8AI/nJi1CSL6JXepT1WYRuQLAM0iV+u5S1VXemG5wrh5/1Rl4WXR4+0j7yuv2I981tTc329qaF+fY2sTow3X+2JPMMcc5hblfT+praqthXzk2CiYA7CvwQ5wx9lEE6h3Nzj7uVf3hjnaUqVRiuan9z2PvRMaLh0aGU5TYj7riVrsK02OMs007Rbse6WE90R0o3mVV51fVJ2EXHQghnRje4UdIoND8hAQKzU9IoND8hAQKzU9IoGR1tb+jbAHwS0ObdKk9rsKoD44a49VW7HLeK9
9729YeX2drU74TGa+eZZehJo+vMjWvwuNMWMRGR7MqSlOcMYMc7VhH64MDHdU6/l5h0S6jvYAzTW3hY3Y5ddmZ86KFs+wsjv6VnQeOtKWmF20Nbzqa5cJFzpgcwDM/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGS6NX+D94Bnr8mWjv4Onvc5V+Pjs9+yOl/5LVNGudo3rzEp6LDSy+3r+j/vbM572r/TY7mMdmIe9fYvTZefdz+LNE9DQHgMuPRTcah5phxTjfBOqehWPXQC0wNMK72OwfkS6W21jDB1v7qXdF3tlmo2TE88xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGSaKkPmwH8JFpa6wyb/Q+G4M1+cRrMHeysDrTWKxHOjw5vchrdfWOxsz0n/4Exl7axFhuzSoAAcKi7AJi97NkfPtus+RN2GU/AMNiTsZbC7oV4ntfkcawtmY+8bKE5YqG55hSwabazK2/G1QZHs5Y3yjM88xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGSValPRNYD2AFgN4BmVS2PvbF7Hc0qvzn91PBNW2p2lmo6+le2tsxaImm1k4eHM4Ow8WZb+5eDbM1aCHmJk0Y13je1Mkdb6WxznNHf7z38yRxzH86xNxi1JnRGnB0dbt7fHLFp9qP25ryZewc4WidcoToXdf6TVNVb0o0Q0gnh235CAiVb8yuAZ0XkJRGZkYuECCHJkO3b/uNVdZOIDASwUEReV9W9bmhN/1PgPwZCOhlZnflVdVP6+1YAjyBiWXZVnauq5VldDCSE5JzY5heR/UWk956fAZwC4NVcJUYIyS+iqvEGigxH6mwPpD4+zFdVY87eJ2Pi7ezHRvweZ4y3zpQzm+6Lc2ztEiN+hLOreqcB5owX3zG1ndX2NkddbGvWBDFv1uTxjvYtR7NmEAJAKaLrkdXYbY65oMqpYY5yOmfiZ44WA++BTXQ0u8cosNTRrBl/MWf7qWpGhdHYn/lVdR2AUXHHE0IKC0t9hAQKzU9IoND8hAQKzU9IoND8hARK7FJfrJ3FLfVNMeK9nTHeTLv3HM1qFgoAJxhxZ+2/rzrVK6fHKO5c4YheM0ij8efhzlpxzlw6t4zpTI5EmRGvQX9zzIkVh9kbPNmaUgkAy23JKr+NdDZ3rqN5DV5rHc1Y5zEfZFrq45mfkECh+QkJFJqfkECh+QkJFJqfkEBJdrkuDy+TNUb8n2Lua4GjPexo1gpPZfaQhy53tucsybWPvaoVBjnLU40w4hc6aRziaB5eWzpLa8a79qCFh8ZL5G7nar/B1Om2NsgZN+dSR7RXAOuU8MxPSKDQ/IQECs1PSKDQ/IQECs1PSKDQ/IQESucp9TU72g4j7k2y8HB6+LlHZLIR9yYYbXa0eU4aV9rahO7ONg28JZWsh9UecQ6/Nz0HN9iTfuD0Qrx9+mmmdgiejox7x2O9o7kDvddwJ4RnfkICheYnJFBofkICheYnJFBofkICheYnJFDaLfWJyF0AzgCwVVWPSMf6AXgAqfls6wGcq6peZ7zssMpl850xzqw4c+ob4Pf+O8CIe7MLvRmEznJMTU6fvvXDbW2IEe+JA80xT2GLqXl9Bp9wNGsipo+3N3u9qzKnh5/1lHn5NcOeXTjqyjdM7ZUjnY3+0NGs16pXkh5gxP/sjGlDJmf+uwG0LaTOBFChqiMAVKR/J4R0Ido1v6ouBrCtTXgqPr1FZR6AM3OcFyEkz8T9zH+gqtYCQPr7wNylRAhJgrzf3isiMwDMyPd+CCEdI+6Zf4uIlAJA+vtW6w9Vda6qlqtqecx9EULyQFzzPw5gTxe06QAey006hJCkaHe5LhG5D8CJAEoAbAFwLYBHkSpiHQTgbQDnqGrbi4JR20pubTAPr0OjNwvPKuV4H2p6OZoznc5b5usc7O9sNHp9qhKn1FePKlP7P2dPv/jQEW804vc4Y1b/xtZG2p1Qv/baR6Y2wYh/CUeZY8bhFlNrxhxTK4L9pC
3Hwaa22Tj+jbDLiq/r2sj4/HEbsaXyo4yW62r3M7+qnm9IkzLZASGkc8I7/AgJFJqfkECh+QkJFJqfkECh+QkJlM7TwLOz4M2kqjbi33PG/NyWpjvlPGs2GgCsh93ossQoKRU5T/VKZ1+/8O7g+KOjWY0uvVmTWGdLk+0npgF2qc+q3BY75c0ifNfUSvG2qR1q1jeBSfi6qVmtULfhHXNEPzk5Mr4Emd9LxzM/IYFC8xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKF271Odl762b1uBo7mJyBk4jTr+0Zdf6ijDN2Z3dKXKIMVutHu+aY5asMCVg6UJby/m6dT81laMP2NfU/t3ZopWi18BzidMQ1Ht5/AgXmNpwZ5vA5yOjy/GBOeJUnORsLzN45ickUGh+QgKF5ickUGh+QgKF5ickULr21f5YV5QR74p+XKJb6qWk9+yr2x/uOs/UDinpZm/UeEaLnIrEJWPbLsj0KReNtcetsZs2o/qZ6Ek6f1hwg71BPGoqE5rtyTunftJL9rPc+MnaMnsTZyUsAKh1tDcdbYjTF9B6arz+iZUf3h0Zr23xmlDuDc/8hAQKzU9IoND8hAQKzU9IoND8hAQKzU9IoLRb6hORuwCcAWCrqh6Rjs0CcAmAuvSfXa2qT+YryU6PU87r1/gDU3twtr2E04C+djmvYYS9v0aj0rNmtV0qKxthT5rp2dfe14SJ9srsg46L1p4660pzTMvDdqlvqVNHe80o5wHAaCM+DIPNMRuc3nm9Hcs0w37O/tfpMzjEiE8xRwA9e0VP4Jq/z/vOqL3J5Mx/N4CoQvDNqjo6/RWu8QnporRrflVdDKDdRTgJIV2LbD7zXyEiVSJyl4h4naYJIZ2QuOa/DcDBSH2kqoW9IDNEZIaIVIpIZcx9EULyQCzzq+oWVd2tqi0A7gAw3vnbuaparqqZryZACMk7scwvIqWtfp0G4NXcpEMISYpMSn33ATgRQImIbARwLYATRWQ0AAWwHsCleczR5PNldolq2AnmmxEU7bIf9p8XLOp4IsP+w5S2vTnBHlf3liltHbG/qdVutstU26rfiBaqVpljVjXaveLQaJeOHho3xtR6jIkuY7Y87PQEdHjeWioNwK+dcb2NeJ1TzhvpbG+yM5W0r6N5bSOtjozj4c2AvCwy2gtfccbsTbvmV9XzI8J3ZrwHQkinhHf4ERIoND8hgULzExIoND8hgULzExIoiTbwHNz/8/jXqdElip4n2GWjnmMOi4yfNGy4OabYqvHAnYSHswZdYWoVt94fLVjlNQCofttJxC7nod5eQ2tb3YHOuOjGmXBKW0B/R3PW8lpiz1hsWmJt83POvhycUp/Xj/VpI772OmeQ16XTaWh66cW29hdnk1b+xzkNTe1EnLJtG3jmJyRQaH5CAoXmJyRQaH5CAoXmJyRQaH5CAiXRUt+gslJcdef3k9xlh6mpdxa1w7tG/PfxdubtqsYrv51tS32PjY43OGVFOOVIZz0+H+tYWfH4eGvamS9w75XvTRN0pvzNcZqdmlP3AKwaFh1/ovtSc8xPMDkyvt1JoS088xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKIle7e8K1Fd7kymSxLsqPseWGqw+cnZ/OcCYsNSZcF6pqx5zxp0QHf7yTHvISxuc7RnLoQEAvHGnd3zcSxvtIbcZj6sjtRme+QkJFJqfkECh+QkJFJqfkECh+QkJFJqfkEDJZLmuoQDuATAIQAuAuap6i4j0A/AAgDKkluw6V1Xfy1+qHaMJm0yth1NGK6q2l6dqyiqjpPgbXUxphqMNdTSjwlntvFK/6PT367vD1mqcMmDPXra21eg3ebjd1hK7PoyOa4s9pi2ZnPmbAXxHVUcCOAbA5SJyGICZACpUdQSAivTvhJAuQrvmV9VaVV2R/nkHgBoAgwFMBTAv/WfzAJyZryQJIbmnQ5/5RaQMwBgAywAcqKq1QOofBICBuU6OEJI/Mja/iBQDeAjAt1U145
4BIjJDRCpFpLKuri5OjoSQPJCR+UWkO1LGv1dVH06Ht4hIaVovhXFbsarOVdVyVS0fMGBALnImhOSAds0vIoLUJeQaVb2plfQ4gOnpn6cD8KZXEEI6GZnM6jsewIUAqkVkZTp2NYDrASwQkYuRagJ3Tn5SBLYZ8UZYS1MBDfqcqQ3CFlPbmWlSJFGOnm1ry56xtT7GqlbeC7/WaWl40UFHmdq0g6pMzZtTeY0x7JhJ9hirJeBrHbiK1675VfUvAMSQnfQIIZ0Z3uFHSKDQ/IQECs1PSKDQ/IQECs1PSKB0iQae/Yx4MYabYzb/8R1Te6p+iantV2znsdNbXotkz5SY4162pQNOjY5700+nHWRr52BfU+vpbHORox0/MTruTVa874Xo+LYOvEZ55ickUGh+QgKF5ickUGh+QgKF5ickUGh+QgKlS5T64lAybLCplU20OyOOqbbLgM//JHpu1pevsvN4yZb82tBqR5vvbTRBjnW0pTG2d40tTcbnTG30TPtlvMZo1rpc7X3tsqaxAbgJy03Nq1Q6y+5hgrG/OifHDW9Gx5s60GWWZ35CAoXmJyRQaH5CAoXmJyRQaH5CAiXRq/0tsHvkNRrLDwFAX2OpoyJ8YI4ZPtye9NO4Y7GpWVf0PWrmOOLpjuZ1Mh/R4TSSpyHGmCGO5iyFdd1Eexk1jHS2aVQQ9nEmcD1gXEkHADgTZ54+ztZOczY5wYg3OFWHhrOi40/d6OyoDTzzExIoND8hgULzExIoND8hgULzExIoND8hgdJuqU9EhgK4B8AgpKp1c1X1FhGZBeASfFqwulpVn/S2tQ+A/Qyt3ikb9TBKfVvxe3PMgw+cZ2pX2JL737DFiO/0Sl5xJ+EsjDkuSTpeFQWM5xIA8A1H2+xoXoO88dHhFq+Jnzcp6VxbWvtzW5ttzxfDGGOVy4tgT06r7hXdo7JbLpfrQuop/o6qrhCR3gBeEpE9L82bVdV5yISQzkoma/XVAqhN/7xDRGoA518SIaRL0KHP/CJSBmAMgGXp0BUiUiUid4nIATnOjRCSRzI2v4gUA3gIwLdVdTuA2wAcDGA0Uu8MIm8sFJEZIlIpIpV1dd79rISQJMnI/CLSHSnj36uqDwOAqm5R1d2q2gLgDhiXVlR1rqqWq2r5gAEDcpU3ISRL2jW/iAiAOwHUqOpNreKlrf5sGoBXc58eISRfZHK1/3gAFwKoFpGV6djVAM4XkdEAFMB6AJe2t6FG7MILqInUajesNcdVvxYd/+0iu2b3wLPtZRONVc7rVPybo92a431da0s9jrS1prMNwetNGBen/GbOFKx3xixwNK+YHXM5t18aM1qPNMp5AHBHVXT8Y2d2bFsyudr/FwBRkwvdmj4hpHPDO/wICRSan5BAofkJCRSan5BAofkJCZREG3hu/7geC2vvjtSqX7jXHNe4OrrksehlZ2deE8YuzqSv2VpFrkt982ypyZvNOM6I26tdxWeYow014t1j7itmOc9ryPqK8Tp+xGkIWmI8rroemafEMz8hgULzExIoND8hgULzExIoND8hgULzExIoiZb6PvpgG1a/GF3SK+ptz2AqMZowTnDWaKv4L1vrY0vY7mgWp06xtWeeirFBAJOsUhmAMWNsrcKa8Re3BLje0fo6mlXa8pp+eqVbjziNRE9ytH90tLgNWb3ZjMZaj9c7axeOOjU6/l63jDPimZ+QUKH5CQkUmp+QQKH5CQkUmp+QQKH5CQmUZEt973+MNU9Gl/SKrdlXADYaWQ5yymFTH7W1eqd5Y4OTx67F0fGlccs/DhXO7LeKmc5Aozv6frfbQ3bOcrbnNMc8/Ou2dohRnu3p7OoRY806AGjyOkYOcTTrud7ljHFKyHnBKnE6B6vamMnY4j2uNvDMT0ig0PyEBArNT0ig0PyEBArNT0igtHu1X0R6AlgMYN/03/9OVa8VkWEA7gfQD8AKABeqapO3rf49gYuMCR+vO1fZrfkjzc6i4IPG2trmFbZW41
y5b4lch7gAeH3k7okO73T67X35x7a2wVlYedUNjnZKdHw/Z1LST6faWo2jPae29pa19FatPQaljuZUmGL393uv40OKjErAxx04nWfypx8BmKiqo5Bajvs0ETkGwA0AblbVEUilf3HmuyWEFJp2za8p9vxP657+UgATAfwuHZ8H4My8ZEgIyQsZvUkQkW7pFXq3AlgIYC2ABlXdM5N6I4DB+UmREJIPMjK/qu5W1dFI3Us1HtH3QEV+8hKRGSJSKSKVjXE/ExFCck6HrvaragOAPwE4BkBfEdlzwXAIgE3GmLmqWq6q5cXF2aRKCMkl7ZpfRAaISN/0z70AnAygBsAiAGen/2w6AOfObEJIZ0NUnToJABE5CqkLet2Q+mexQFV/JCLD8Wmp72UAF6jqR962RpeKPmvUBDZeta857r4bozc735ns0eBMzujrlHkaFtraTlvKPSWO5pWbYvYMNJngaM4kkj5GqW+7s4zaF75pa2dMsjVjLhMA4NZ10fFttziDnDIxnNKnu0Scl+TDRtzrTWhN1JoB6OsqzshPaLfOr6pVAD5TnVXVdUh9/ieEdEF4hx8hgULzExIoND8hgULzExIoND8hgdJuqS+nOxOpA/BW+tcS2B3WkoR57A3z2JuulscXVNUrLH5Coubfa8cilapaXpCdMw/mwTz4tp+QUKH5CQmUQpp/bgH33RrmsTfMY2/+ZvMo2Gd+Qkhh4dt+QgKlIOYXkdNE5K8iskZEvMWn8p3HehGpFpGVIlKZ4H7vEpGtIvJqq1g/EVkoIqvT3532pHnNY5aIvJM+JitF5PQE8hgqIotEpEZEVonIlel4osfEySPRYyIiPUXkRRF5JZ3HD9PxYSKyLH08HhCRHlntSFUT/UJqavBaAMMB9ADwCoDDks4jnct6ACUF2O8JSE0cfbVV7GcAZqZ/ngnghgLlMQvAfyZ8PEoBjE3/3BvAGwAOS/qYOHkkekwACIDi9M/dASxDqoHOAgDnpeO3A/hWNvspxJl/PIA1qrpOU62+7wfgNGb+20NVFwPY1iY8Fam+CUBCDVGNPBJHVWtVdUX65x1INYsZjISPiZNHomiKvDfNLYT5BwPY0Or3Qjb/VADPishLIjKjQDns4UBVrQVSL0IAAwuYyxUiUpX+WJD3jx+tEZEypPpHLEMBj0mbPICEj0kSTXMLYf6oLiOFKjkcr6pjAUwBcLmInFCgPDoTtwE4GKk1GmoBJLZUiYgUA3gIwLdVdXtS+80gj8SPiWbRNDdTCmH+jQBar89jNv/MN6q6Kf19K4BHUNjORFtEpBQA0t+3FiIJVd2SfuG1ALgDCR0TEemOlOHuVdU9ja0SPyZReRTqmKT33eGmuZlSCPMvBzAifeWyB4DzADyedBIisr+I9N7zM4BTALzqj8orjyPVCBUoYEPUPWZLMw0JHBMREQB3AqhR1ZtaSYkeEyuPpI9JYk1zk7qC2eZq5ulIXUldC+C/C5TDcKQqDa8AWJVkHgDuQ+rt48dIvRO6GEB/ABUAVqe/9ytQHr8FUA2gCinzlSaQx98h9Ra2CsDK9NfpSR8TJ49EjwmAo5BqiluF1D+aH7R6zb4IYA2ABwHsm81+eIcfIYHCO/wICRSan5BAofkJCRSan5BAofkJCRSan5BAofkJCRSan5BA+X+oAC6reFaYfAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(img.permute(1, 2, 0), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFy5JREFUeJztnVuM3dV1xn/LYxviS/ANm/GlNsbEAayUIAsRgSqqKIhWkUgeEoWHiEpRzEOQGikPjXgJL5VQlUt5qCI5BYVIuUpJGh5QmwhVokmkiAFBINycoDG+jO8YbAMxHq8+zHE7mNnfPnNm5pwh+/tJlmfOOv//Xmef/zfn8u21dmQmxpj2WDDoBIwxg8HiN6ZRLH5jGsXiN6ZRLH5jGsXiN6ZRLH5jGsXiN6ZRLH5jGmXhTA6OiNuBB4Ah4N8z8351/6VLl+aKFSumjF1yySVyrHPnzhVjf/7zn4sxtYJxaGhIjqmO7TW2cKGe8gULyn+Pe30sESHHrMV7yafG+fPne4qp6wDg7NmzxdgHPvCBYkzNn7q+QD9nam5r15+idN4TJ05w5syZrp7QnsUfEUPAvwGfAPYDT0TEI5n5fOmYFStWcPfdd08Z27p1qxzv+PHjxdjevXuLsXfeeacYW758uRxTXUgqpi7e1atXyzEvvfTSYkxd+KU/qlD/w6ouUHVhj4+Py/Oq+Ntvv12MnTlzphg7ceKEHHN0dLQY27FjRzF22WWXFWN/+tOf5JjLli0rxtQfezWmuoYAFi1aNOXtDzzwgDxuMjN5238j8MfMfCUzzwI/Au6YwfmMMX1kJuLfAOyb9Pv+zm3GmPcBMxH/VO8V3/MhMCJ2RcRIRIyot3PGmP4yE/HvBzZN+n0jcPDiO2Xm7szcmZk7ly5dOoPhjDGzyUzE/wRwdURcGRGLgc8Bj8xOWsaYuabnb/sz81xE3AP8FxNW30OZ+Qd1zKWXXso111wzZeytt96S46lvpNW39ocOHSrGalbV8PBwMfbaa68VY8eOHevpOIC1a9cWY+pbZWUbLV68WI6p3IA33nijGNu3b18xVhtXzb1yaE6fPi3HPHr0aDFW+oYcYNWqVT2PqZwLZS+qb/SVmwTaFeqWGfn8mfko8OiMszDG9B2v8DOmUSx+YxrF4jemUSx+YxrF4jemUSx+YxplRlbfdBkfHy/6xnv27JHHqqXBW7ZsKcZUVVqtpFJ57sqnVv6t8n1BV4EtWbKkp1jNE1ZrLA4cOFCMqTUUoNdJKE6ePFmMqXUHAKdOnSrG1POi1hbUSnrVGhQVU15+bW1GqSJwOmXCfuU3plEsfmMaxeI3plEsfmMaxeI3plEsfmMapa9W37lz54oNGJUlB7BhQ7lD2MqVK4sx1diyVl6r7DyVryrxXL9+vRxTNTxR1pCyeN588005piqDVQ0za81ZlF2l7LPnny/2gOXIkSNyTFV+e/Dge3rNdJWPsg9Bl1orC7HX5p5Qntuajt51367vaYz5i8LiN6ZRLH5jGsXiN6ZRLH5jGsXiN6ZR+mr1LVq0iCuuuGLKWG0DRmWnKDtK2YC1ajdVBbZmzZpiTNlNta6sH/zgB4sxZRsdPny4GKvtqaesSWXn1SrP1GNV1YLPPPNMMVbbq09ZwupY1b1XxaD3TrqqErO2v+JMNvm8gF/5jWkUi9+YRrH4jWkUi9+YRrH4jWkUi9+YRumr1Tc0NFS0jlT
1HejGjbWqqxLKagFd1acqspQ1pDbxBG0vqjFVNVfNNlIbWNbsPIWqlFMNWdW8X3fddXLMbdu2FWOqoahq1lpD5auuTdXMtbaJbMn2rR03mRmJPyJGgVPAOHAuM3fO5HzGmP4xG6/8f5uZ+uXMGDPv8Gd+YxplpuJP4JcR8WRE7JrqDhGxKyJGImLk9ddfn+FwxpjZYqZv+2/OzIMRsRb4VUS8mJmPT75DZu4GdgNs27at+28jjDFzyoxe+TPzYOf/I8DPgRtnIyljzNzTs/gjYmlELL/wM3Ab8NxsJWaMmVtm8rZ/HfDzzkaEC4EfZOZ/qgPGx8eLvufy5cvlYKosVfnUykutde/du3dvMabKOGfimysvX5U1q7UFau0AwPHjx4sx5dXXypPVY1GxzZs3F2O33XabHFOtF9m+fXsxpp6X0dFROabqjtzr/NXWoNRK4LuhZ/Fn5ivAX884A2PMQLDVZ0yjWPzGNIrFb0yjWPzGNIrFb0yj9LWkNyKK5aW1bqSqi6wqY1QbHtY2Q1TdXpWFo2K10lFl56k5UvZizRaaThlot/nUOHnyZDFW6vAMsGnTJnleZZ+p7sdqjmqWsBpTPZ/qWlDHAbz11ltT3u6NOo0xVSx+YxrF4jemUSx+YxrF4jemUSx+Yxqlr1ZfZhbtlpodpWwlVbmnKtpqHYNV11tVCacsxNqmj2rTzF47ISnrEcq2Eejus52KziLqeVHnvfzyy4ux2nXS62amyiJT8wO6wlNtvKqsvlolpqqK7Ba/8hvTKBa/MY1i8RvTKBa/MY1i8RvTKBa/MY3SV6tvwYIFxcaENTtFNfhUVWnKHlObf4K2f3rdqLPWwFM1g1RzpHKtPU4VVxVtqqkqaJtLPZ9qbsfGxuSYah7UdaKq6GqWsLJ21bHKKq1VWvZaiTkZv/Ib0ygWvzGNYvEb0ygWvzGNYvEb0ygWvzGNYvEb0yhVnz8iHgI+CRzJzB2d21YBPwa2AKPAZzNTtzhlosNsqYzxzJkzOlHh/Sq/WZXIKk8ddEdc5WErz/j06dNyTOXfqrJmNX+1MY8ePVqMqdLS2uaqqtRVbcap5r22OejKlSuLMVVirKh1xFX5KnqddyjrYTr+fzev/N8Fbr/otq8Cj2Xm1cBjnd+NMe8jquLPzMeBixvY3wE83Pn5YeBTs5yXMWaO6fUz/7rMHAPo/F9sSRIRuyJiJCJG1EYNxpj+Mudf+GXm7szcmZk7a2ukjTH9o1fxH46IYYDO/0dmLyVjTD/oVfyPAHd1fr4L+MXspGOM6RfdWH0/BG4F1kTEfuBrwP3ATyLiC8CrwGdmnEilG6myeJTVpyyTmWwOunr16p7Oqzb/BDh48GAx1mu5b63rrzqvepxbt26V51Xzt2fPnmLsqquuKsZqHYNVXHUMVp19a1ap6iisbELVAbrWcbm0mWnNCp1MVfyZeWch9PGuRzHGzDu8ws+YRrH4jWkUi9+YRrH4jWkUi9+YRulr997z588XLQxVfVeLq464qpJLVZ2B7sqqKvfUMuZaJ91e7TwVq9moGzZsKMZKlhLUu9qqx6osz4997GPF2MaNG+WYzzzzTDGmKt567dQM2l5U1ZbKeqxVuZbsxelYfX7lN6ZRLH5jGsXiN6ZRLH5jGsXiN6ZRLH5jGqWvVl9mFi2pWsPCXlFWX83CKTUbBV1JeOzYsWJMVXKBtgnVppmqCqxmySkLTFWsqQaUAPv37y/GlA2oGmJecsklckyVr3q+1RwpuxP0NabmQDWB7fU5q1WqTsav/MY0isVvTKNY/MY0isVvTKNY/MY0isVvTKNY/MY0Sl99/qGhoWKZbM1zVyWXyttVvnmt/FGtPTh06FAx9uKLLxZjBw4ckGP2Wgb7zjvvFGO17rO9etG150ytS1B+tHrOamOq9ReqTFatH1iyZIkcU10nKqbWD6g1JlAu4a6tg5iMX/mNaRSL35hGsfiNaRSL35hGsfiNaRSL35hG6WajzoeATwJHMnN
H57b7gC8CF2o6783MR6uDLVzImjVrpowpiwZ0yaqKqRLaWhmxst327dtXjKnNNmtdipVtqVCdiGv2j5o/ZS+q0lvQj1Vt4qnGrJXXKpSlqbox18pkVfdoZU2WtADaYgXdrblbunnl/y5w+xS3fyszr+/8qwrfGDO/qIo/Mx8H9L7Sxpj3HTP5zH9PRPw+Ih6KiJWzlpExpi/0Kv5vA1cB1wNjwDdKd4yIXRExEhEj6rOcMaa/9CT+zDycmeOZeR74DnCjuO/uzNyZmTtXrVrVa57GmFmmJ/FHxPCkXz8NPDc76Rhj+kU3Vt8PgVuBNRGxH/gacGtEXA8kMArc3c1gb7/9Nnv27JkypjYtBL2BparOUxs31uxFtVGnssdU9ZiqTgRt8ajqMvVYahuSqopANe+qay1om0vNg3qc6jkBWL58eTGm7Dw1f2ojTtBWoJp7dc3X9FCKT8cqroo/M++c4uYHux7BGDMv8Qo/YxrF4jemUSx+YxrF4jemUSx+YxrF4jemUfravff06dP89re/nTJW8zUVyk8eGRkpxmqlmjfddFMxtmXLlmJMde+tee7Kp+7VM64tq1bnVfnUPHdV6qq8c9W5tlbKqtYPbN++vRhTHY5Xr14tx1TXn5o/tR6k1nG5tvagG/zKb0yjWPzGNIrFb0yjWPzGNIrFb0yjWPzGNEpfrb6zZ8/yyiuvTBmrldeqDrTKalEdZFWJJ8CHP/zhYkx1kVUdg2tlsCqnXju21uZWWXKqy26tOYt6rOvWrSvG1q9fX4ypjsqgrT51Da1YsaIYq3XSVdeYsuTUvNfsxZK1W9vIdDJ+5TemUSx+YxrF4jemUSx+YxrF4jemUSx+Yxqlr1bf+Ph4sZKpVmGnyMxiTFlnmzdvlufttQpMPRZV5QXaAlNjqq6ttblVFpjKt2ZHqWrBbdu2FWNqDpSNCtpePHbsWDGm7NCVK/WGVKp7tDqvsm6vvPJKOWbpObXVZ4ypYvEb0ygWvzGNYvEb0ygWvzGNYvEb0yjdbNS5CfgecAVwHtidmQ9ExCrgx8AWJjbr/GxmvlY5V7GSqbaBpULZO1dffXUx9pGPfESe9/z588WYqi5TFX/Dw8PFGGhbSW2oqSrsXn31VTmmsoc+9KEPFWOqKg10tduaNWvksSUOHDgg48peVPOn7M5aNaW6Fnq1AdVGsGrMWgXnZLpR3DngK5l5DXAT8KWIuBb4KvBYZl4NPNb53RjzPqEq/swcy8ynOj+fAl4ANgB3AA937vYw8Km5StIYM/tM6712RGwBPgr8DliXmWMw8QcCKDdbN8bMO7oWf0QsA34KfDkzdTuVdx+3KyJGImJEff4xxvSXrsQfEYuYEP73M/NnnZsPR8RwJz4MHJnq2MzcnZk7M3Nn7QsiY0z/qIo/JpqQPQi8kJnfnBR6BLir8/NdwC9mPz1jzFzRTQnQzcDngWcj4unObfcC9wM/iYgvAK8Cn5mbFI0xc0FV/Jn5a6DUgvTj0xlswYIF1e61JRYtWlSMqdJS5WGrjq3Qe1db1X1WrQEAXXaqSpcvv/zyYqzUMbmbY2+44YZibO/evfK8iiVLlhRjR45M+QkS0F49wJtvvlmMqbJndZ3UyojVOgC1BkCt6Th58qQc89ChQ1PeXpufyXiFnzGNYvEb0ygWvzGNYvEb0ygWvzGNYvEb0yh97d4L5ZLDWkmvsgiVDagskxdffFGOqbr7qhJQtTmjKhMG2L9/fzGmymDVmDXbaMeOHcWYmtt9+/bJ86oNLksbTQK89NJLxVjNypqO1TUZZeedOHFCHqvmSMXUNV8bs6QjZQe/Z/yu72mM+YvC4jemUSx+YxrF4jemUSx+YxrF4jemUfpq9UVEsXqqZlGoyillj1122WXFWM12O3r0aDGmrL7R0dFirLRR6QWUBabGVF1rax1d1WacY2NjxZiqQARtTapKw9/85jfFWG0jSrUxqzpWXV+qghN0l2I196oatWbPlixNW33GmCoWvzG
NYvEb0ygWvzGNYvEb0ygWvzGN0nerr1TlpBov1uJDQ0PF2C233FKM1SwcZYGpRpzqvMqKAti0aVMx9vLLLxdjr71W3iO1Nqba5FPZobXzqs0mlbWm5lY1GwVth6prSNmWan5A27PKslMVf7U9Lko2tGpSejF+5TemUSx+YxrF4jemUSx+YxrF4jemUSx+YxrF4jemUao+f0RsAr4HXAGcB3Zn5gMRcR/wReCC4XhvZj6qzpWZxVJE5dWDLn/cunVrTzF1zlpOyo9XMVX+CXrz0D179hRjqvxYrVcAXQaqNpqsee4qJ9WNefv27cVYbaPTs2fPFmOnT58uxtQaALUmAXSXYnUtKC+/1s26dG2qLs4X080in3PAVzLzqYhYDjwZEb/qxL6VmV/vejRjzLyhmy26x4Cxzs+nIuIFYMNcJ2aMmVum9Zk/IrYAHwV+17npnoj4fUQ8FBFTbjYeEbsiYiQiRmpveY0x/aNr8UfEMuCnwJcz8w3g28BVwPVMvDP4xlTHZebuzNyZmTtVqyljTH/pSvwRsYgJ4X8/M38GkJmHM3M8M88D3wFunLs0jTGzTVX8MfH14YPAC5n5zUm3D0+626eB52Y/PWPMXNHNt/03A58Hno2Ipzu33QvcGRHXAwmMAndXB1u4sNjRVdkloLvwXnfddcWYKsdU3WUBVq6c8msMQFtDqpRVWY+g7ShVAqosLlXmCjrfWhdZhbKr1ONUFljteyM1D6p7r7q+Dh8+LMdcv359MbZly5ZiTNmo6rkGuPbaa6e8fTofrbv5tv/XwFTmofT0jTHzG6/wM6ZRLH5jGsXiN6ZRLH5jGsXiN6ZR+tq9d/HixUXrY+3atfJYZWGobrnK5qptYKlsrr179xZjM7FwVLdcVbGl7KZapZey1oaHh4sxZRGCfl5Ut2FFrTutqlBUVt+SJUuKsdp1oh6nqghUtqWyLKFsU9c2Mp2MX/mNaRSL35hGsfiNaRSL35hGsfiNaRSL35hG6avVt3DhwmLTx9rGhCp+/PjxYkzZUcpWA20blRqRgrbWahaXahraa4Xdxo0bZfzMmTPFmHosNTtKWWSq+efrr79ejNU2dFXPqbqGlCVXq4pU86DmQM2tqihVOdUa4U7Gr/zGNIrFb0yjWPzGNIrFb0yjWPzGNIrFb0yjWPzGNEpfff6hoSGWLVs2ZazmfysvWnnuag1AzRNVm0KuW7euGFPloeo40D628pvHxsaKsdqGpMobP3XqVDGm1kGAnnvVLVedt1aerPx69TjVtVCbP1VmrGJqfUBt3Utpo9PaBp/vum/X9zTG/EVh8RvTKBa/MY1i8RvTKBa/MY1i8RvTKFGza2Z1sIijwOS2t2uAY31LoI7z0cy3fGD+5TTofDZnZrleehJ9Ff97Bo8YycydA0vgIpyPZr7lA/Mvp/mWj8Jv+41pFIvfmEYZtPh3D3j8i3E+mvmWD8y/nOZbPkUG+pnfGDM4Bv3Kb4wZEAMRf0TcHhEvRcQfI+Krg8jhonxGI+LZiHg6IkYGlMNDEXEkIp6bdNuqiPhVROzp/K9bus59PvdFxIHOPD0dEX/fx3w2RcR/R8QLEfGHiPjHzu0DmSORz8DmaLr0/W1/RAwBLwOfAPYDTwB3ZubzfU3k3TmNAjszc2D+bET8DXAa+F5m7ujc9i/Aicy8v/NHcmVm/tMA87kPOJ2ZX+9HDhflMwwMZ+ZTEbEceBL4FPAPDGCORD6fZUBzNF0G8cp/I/DHzHwlM88CPwLuGEAe84rMfBw4cdHNdwAPd35+mImLa5D5DIzMHMvMpzo/nwJeADYwoDkS+bxvGIT4NwD7Jv2+n8FPWgK/jIgnI2LXgHOZzLrMHIOJiw1YO+B8AO6JiN93Phb07WPIZCJiC/BR4HfMgzm6KB+YB3PUDYMQ/1StWAZtOdycmTcAfwd8qfOW17yXbwNXAdcDY8A3+p1ARCwDfgp8OTPf6Pf4XeQz8DnqlkGIfz+wadLvG4GDA8jj/8j
Mg53/jwA/Z+KjyXzgcOez5YXPmEcGmUxmHs7M8cw8D3yHPs9TRCxiQmjfz8yfdW4e2BxNlc+g52g6DEL8TwBXR8SVEbEY+BzwyADyACAilna+sCEilgK3Ac/po/rGI8BdnZ/vAn4xwFwuiOsCn6aP8xQTzfseBF7IzG9OCg1kjkr5DHKOpstAFvl07I9/BYaAhzLzn/uexP/nspWJV3uYaGj6g0HkExE/BG5loirsMPA14D+AnwB/BbwKfCYz+/IlXCGfW5l4O5vAKHD3hc/bfcjnFuB/gGeBC50472Xic3bf50jkcycDmqPp4hV+xjSKV/gZ0ygWvzGNYvEb0ygWvzGNYvEb0ygWvzGNYvEb0ygWvzGN8r/Yen8l9w20CQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 3, 32, 32]), torch.Size([1, 1, 32, 32]))" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "conv = nn.Conv2d(3, 1, kernel_size=3, padding=1) # <1>\n", "output = conv(img.unsqueeze(0))\n", "img.unsqueeze(0).shape, output.shape" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " conv.bias.zero_()\n", " \n", "with torch.no_grad():\n", " conv.weight.fill_(1.0 / 9.0)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAF79JREFUeJztnV+MXdV1xn8L4z/YGBvb2Bjj1BD5IVFUSDRCkaiiNGkjGlUikZooeYh4QHFUgdRI6QOiUkOlPiRVkygPVSqnoJAqhND8UVAVtUEoFcoLiUMJfwIFgt1gPLGNwdgQAni8+nCv1WFy1zd3zsyca7K/nzSaO2fdfc4++5xv7r37u2vtyEyMMe1xzqQ7YIyZDBa/MY1i8RvTKBa/MY1i8RvTKBa/MY1i8RvTKBa/MY1i8RvTKOcupnFEXAN8GVgB/Etmfk49f926dblx48aRMfVNw5mZmZHbzzmn/t+1YsWKMta13bnnjh4utT91XqdPn+4U65OIWNJY1/11Hcfq3lH76xrrej2rmGpTjdWJEyd45ZVX6oGcRWfxR8QK4J+APwUOAj+NiLsz8xdVm40bN3LDDTeMjP32t78tj/XSSy+N3L5mzZqyzQUXXFDG1q1bV8Y2bNhQxjZt2rTg/b3++utlrDovgN/85jdlbKlR//BWrlxZxtQ/vVWrVo3cvnr16k7HOnXqVBk7ceJEGavG+NVXXy3bVP8wAF577bUypq6ZilX3/iuvvFK2qV6I7rjjjrLNXBbztv8q4KnMfDozXwPuBK5dxP6MMT2yGPHvAJ6Z9ffB4TZjzJuAxYh/1OeK3/lAFBF7ImJfROx7+eWXF3E4Y8xSshjxHwR2zvr7UuDQ3Cdl5t7MnMrMKfXZ2BjTL4sR/0+B3RFxWUSsAj4G3L003TLGLDedZ/sz81RE3Aj8JwOr77bMfFS1iYhyllLN9Faz+uedd96C24C2a9TsfDW7ff7553c6VjUWoMdDWUDV8dSMvhpH9W6tmtGHelZfzfar8VBukJqdr2b71RiqfnS189R9VTkBysWoxn4hFvGifP7M/AHwg8XswxgzGfwNP2MaxeI3plEsfmMaxeI3plEsfmMaZVG
z/V2obBmV0aWsqIquSTPKmjt27NjI7ZdddlnZpkoGAp0kor4NqWyvytJTdqSyttTYq3aVLaoSdNQ90DWx5+jRoyO3q6QZlRSm7g+VLKRi1f2o7tMumYBz8Su/MY1i8RvTKBa/MY1i8RvTKBa/MY3S62z/zMwMJ0+eHBlTSSLVTHXXBB2VCKJm2asZW1XOqjpf0P1Xs9GKtWvXjtyuEnuqNqATpNQMfOVkdF0VWs1iq5n0qh+qTdd7p8uMPtTOlHKsKmdkIePrV35jGsXiN6ZRLH5jGsXiN6ZRLH5jGsXiN6ZRerf6XnjhhZExZUVVdk2XxBLQySpdln568cUXyzbKelGJLKofqv/VOKrxUCjbSyUYVbaXus7KKlMWbBdLTPVDWZhdk3e6jJWyFdW9My5+5TemUSx+YxrF4jemUSx+YxrF4jemUSx+YxplUVZfRBwATgIzwKnMnFLPn5mZKW0ZZWtUVp/KOFM24IYNG8qYqrlXLTWlbCN1XspGUxZhl2xGZZWpGnhda+5V1qIajyNHjpSxZ555pozt37+/jFXHU/eHyqhU/Vd2nhrHCtXHKrYQC3ApfP4/zsznlmA/xpge8dt+YxplseJP4IcR8bOI2LMUHTLG9MNi3/ZfnZmHImIrcE9EPJ6Z981+wvCfwh7Qyz0bY/plUa/8mXlo+PsI8D3gqhHP2ZuZU5k5pdZmN8b0S2fxR8S6iFh/5jHwAeCRpeqYMWZ5Wczb/m3A94bWwrnAHZn5H6pBZpa2nbJJKrrYJ6Az5lQxy8rSU8VHV65cWcaULaPsJpXFVll9aqyUValsRZVdWJ23WqLs0KFDZezRRx/t1K76qHnhhReWbRTqmikbUMWq+1FdF2UDjkvnPWTm08AVi+6BMWYi2OozplEsfmMaxeI3plEsfmMaxeI3plF6LeCZmaXloSyg6stBKitOWWVq3bRqXUCobS9lh6nMQ2U5dikkCrU9tBQFHxdC1X9l6arsQmVvdrExlT2rYupaq36o+7G6j1XR1Sq2kOvsV35jGsXiN6ZRLH5jGsXiN6ZRLH5jGqX32f5qaSI1m1vN9qvZUDVbrmb71Wx01e75558v26gaBiqmZpxVanS1T1W38IILLuh0LDXzXTkSanzVNduyZUsZ27ZtWxm79NJLR25X56X6eOzYsU7tutQFVPfAUuBXfmMaxeI3plEsfmMaxeI3plEsfmMaxeI3plF6t/qq5Adl9VX14LokuIBOwFD2YZWAcfLkybJNl4QOgIsuuqiMKWuuqk2n2qhafIrKtoV6jJUtqpK7lFW5devWMrZr166R29W9c/z48TKm+qiSuJStW/VF7a+yKp3YY4yZF4vfmEax+I1pFIvfmEax+I1pFIvfmEaZ1+qLiNuAPweOZOY7hts2Ad8CdgEHgI9m5gvz7Utl9SkLRVlpFV0XBe1SO08tyaWsQ7XkkrKGNm7cWMbWr18/cnvX8VBWpbLtqnp8amktVcNP1VaszlnF1HXpYrHNF1OZpJXl26XuoroX5zLOK//XgGvmbLsJuDczdwP3Dv82xryJmFf8mXkfMPdf/LXA7cPHtwMfWuJ+GWOWma6f+bdl5jTA8Hf9FStjzFnJsn+9NyL2AHug++dOY8zS0/WV/3BEbAcY/j5SPTEz92bmVGZOLXdZImPM+HQV/93AdcPH1wHfX5ruGGP6Yhyr75vAe4EtEXEQ+CzwOeCuiLge+BXwkXEOlpll0UdlsVUZYiorTmX1dS38WfVdHUvZcps2bSpjVXYewNq1a8tYlaGnzlllVE5PT5exgwcPlrEqM05lzKlxVJmHKkuzstiUddglaxJgx44dZUxlEVaFP1WbqpDoQj5azyv+zPx4EXr/2Ecxxpx1+Bt+xjSKxW9Mo1j8xjSKxW9Mo1j8xjRKrwU8I6LMZFOZVJUtozLflLWlrCG13lpl9ansPHVeav25rgU8q+MpW1TZbypz7+jRo2WsytBTGZrqvBRdrrUaD4WyHJV1q9pVfVH
36YEDB0ZuX+qsPmPM7yEWvzGNYvEb0ygWvzGNYvEb0ygWvzGN0rvVV2VuKbtsIfbFGbpkCUJt56mYykZTGXhq/bmua+tVNRO6Zswp+0rZZeq8K1RGmrJ11ThW46EKk7788stlTJ2zKripallU96q6F9X9PS5+5TemUSx+YxrF4jemUSx+YxrF4jemUXqd7VeoGfguSRjKPehKNRut6sGpWWo1A6xmjrvUIFQzx2q2/5JLLiljmzdvLmMqWahCzcCrc+7iEqjls6ol5eaLvfBCvWKdWo6uumbKoakS0BbiAviV35hGsfiNaRSL35hGsfiNaRSL35hGsfiNaZRxluu6Dfhz4EhmvmO47Rbgk8CZIm43Z+YP5tvXOeecUyZ8qESWyq5RySOqdp5aBklZOZWlpOrtKatP2TJdagmqfSq7VI2HskzXr19fxqo+qmumEmrUdVG2XZeahl2TzJRdrWoXVpaeGt/KJlYW8VzGeeX/GnDNiO1fyswrhz/zCt8Yc3Yxr/gz8z6gLuFqjHlTspjP/DdGxEMRcVtE1EuXGmPOSrqK/yvAW4ErgWngC9UTI2JPROyLiH3qc5sxpl86iT8zD2fmTGaeBr4KXCWeuzczpzJzaiFrhxtjlpdO4o+I7bP+/DDwyNJ0xxjTF+NYfd8E3gtsiYiDwGeB90bElUACB4BPjXOwVatW8Za3vGVkTGWPVZaHsgdVpp2y0Z577rkyVmWdqXc06qOOWgpLZbipbK8qe0z1Q1lUyjpStldlpSnL66WXXipjKqvvyJEjZawaD5Vlp1BWpeqjilX399atWxfcRt0bc5lX/Jn58RGbbx37CMaYsxJ/w8+YRrH4jWkUi9+YRrH4jWkUi9+YRum1gOd5553HFVdcMTK2ZcuWsl1lr6iMOWXJHDt2rIw98cQTZezJJ58cuf3EiRNlG2WxqcKZKitRxapsuq5ZccpyVHZZZfWpsVcWrMo8VJZpdd6qH2o8lIXctZBrdd7Kyq7OS2V8zsWv/MY0isVvTKNY/MY0isVvTKNY/MY0isVvTKP0avWtWbOG3bt3j4xV26EuqKgymLoWYdy/f38Zq+whZTWpPirLTvVxw4YNZayyD5VV9uKLL5YxleWo2lU2oCrSqcZKjYdaF7CyvlR2oVpzT42HsuZUrBoTZdtVsYVkK/qV35hGsfiNaRSL35hGsfiNaRSL35hG6XW2/9xzz2Xz5s0jY9u2bSvbVTXmVO05hZr5VjO9v/71r0duV7P9Xeu6qaW8VBLUhReOXkJB7U/Nlh8+fLiMqYSmKllFjf3GjRvLmLrWKlYdTyVVqVl2lSClrqdqV7kman+Vm7UQTfiV35hGsfiNaRSL35hGsfiNaRSL35hGsfiNaZRxluvaCXwduBg4DezNzC9HxCbgW8AuBkt2fTQza5/s//e34E5W9eBUfTlleSg7T9V2qywxtcyUsthUkovap0ouUTZghUpyOXr0aBlT/a8STFSdu4svvriMVctTgV4urUL1o7JL50NZc2qsKttOaaWLjuYyziv/KeAzmfk24N3ADRHxduAm4N7M3A3cO/zbGPMmYV7xZ+Z0Zj4wfHwSeAzYAVwL3D582u3Ah5ark8aYpWdBn/kjYhfwTuB+YFtmTsPgHwRQLylqjDnrGFv8EXE+8B3g05lZf6/zd9vtiYh9EbFPfdY2xvTLWOKPiJUMhP+NzPzucPPhiNg+jG8HRi6Snpl7M3MqM6e6TqQYY5aeecUfg2nFW4HHMvOLs0J3A9cNH18HfH/pu2eMWS7Gyeq7GvgE8HBEPDjcdjPwOeCuiLge+BXwkfl2lJmlBVfZeVDbRmpZJWWtqI8fKqOrqp2nlpnqWitOWX2qTltlfyobSo2j6r/K0KtQtpzK6quyQUHXx6v6qOonqneo6lhqjFV2ZGUHd8kIVW3mMq/4M/PHQGUqvn/sIxljzir8DT9jGsXiN6ZRLH5jGsXiN6ZRLH5jGqXXAp5QWy8qC6+yQlSbrhl/69a
tK2OXXHLJyO2rVq0q2yirTNleqiiostgqq1LZkSqmrKO1a9eWserclGWnsvp27txZxpQ1V1m+avkvtdSbOme1T1XAs7Ju1XXuYrPOxa/8xjSKxW9Mo1j8xjSKxW9Mo1j8xjSKxW9Mo/Rq9amsvi721YoVK8o2yqJSxQ9VuyoTTGWjqUKRVZYg6CKd1dpuUGdHKstRjWPX7LfqvNU5q/Xz1P2hshKrduqclfWpsi2V5avOrbJFlT2o+jEufuU3plEsfmMaxeI3plEsfmMaxeI3plF6n+2vZl9VDb8qpmrZnThRVxdXMTXLXrVTs8NqBli5BMqRUDPH1ay+mjnukqAD2smoYippRjkSBw8eLGNdEnHUdVGOj1rOTY3Vpk2byljlgKglyipNLGQZL7/yG9MoFr8xjWLxG9MoFr8xjWLxG9MoFr8xjTKv1RcRO4GvAxcDp4G9mfnliLgF+CRwdPjUmzPzB2pfp0+fLmvrKfutsoDUcldPPfVUGdu/f38Ze/bZZ8vY0aNHR25XiSXKelH1ApXdpGrFVRabGitllSk7r4tVqeonKptV1TRUY1VZbMqWU0u2VfUkQY9jVf8R4PLLLx+5XdmDS8E4Pv8p4DOZ+UBErAd+FhH3DGNfysx/XL7uGWOWi3HW6psGpoePT0bEY8CO5e6YMWZ5WdBn/ojYBbwTuH+46caIeCgibouIOrnbGHPWMbb4I+J84DvApzPzBPAV4K3AlQzeGXyhaLcnIvZFxD5VhMIY0y9jiT8iVjIQ/jcy87sAmXk4M2cy8zTwVeCqUW0zc29mTmXmlKriYozpl3nFH4Np21uBxzLzi7O2b5/1tA8Djyx994wxy8U4s/1XA58AHo6IB4fbbgY+HhFXAgkcAD41345Onz5dWnqHDx8u21WZVNPT02Wbxx9/vIypduqjSZVFqDLmVLaispSUzaOy8KradMoOUxaVquGnbMBqrKrls0CPfdc6g1VM9UNdT9VO9VFZnNU7YrW/anzVPTWXcWb7fwyMMm2lp2+MObvxN/yMaRSL35hGsfiNaRSL35hGsfiNaZReC3ieOnWK48ePl7GKKqPr0KFDC24DulCkysKrrC1lNVVZjKCLSKovRKlYtYyT6qPKVFN2kyokWlmLyvpUFpuyI1U/qvNW95vKMFX3lUKNf3Ufqz5W46uWNfudfYz9TGPM7xUWvzGNYvEb0ygWvzGNYvEb0ygWvzGN0qvVNzMzUxaSrCwqqAtkKitEFcdU7dT6f5W9otqoTDVlsSnLUVmElR25devWso3qf5djQV3cU2UrqnXwVB+VDVgVO1Xn1fWaqftKredYnXeX4q+qf3PxK78xjWLxG9MoFr8xjWLxG9MoFr8xjWLxG9MovVp9mdmp8GBlsSmrSWWjqbXuVMZflZHW1bJTGW5q/T+V/Vad9+bNm8s2auy7FvesYqrNli1bypjKjlRjVVl66rooO1JdM3XvqGy7yv5WY19pwlafMWZeLH5jGsXiN6ZRLH5jGsXiN6ZR5p3tj4g1wH3A6uHzv52Zn42Iy4A7gU3AA8AnMrPOvjhzwGIGU83OVzOzamZTzeirmm9q5riacVZJGwuZfZ2NShJRs8rV7LZarksluailwdQ4VuOv6g+uXr26jKnZeZWIUyWMqWumksIuuuiiMqbGSjkqlROgHI7K8Vnq2f5Xgfdl5hUMluO+JiLeDXwe+FJm7gZeAK4f+6jGmIkzr/hzwJl/nyuHPwm8D/j2cPvtwIeWpYfGmGVhrM/8EbFiuELvEeAe4JfA8cw88970ILBjebpojFkOxhJ/Zs5k5pXApcBVwNtGPW1U24jYExH7ImKfWt7YGNMvC5rtz8zjwH8B7wY2RsSZ2btLgZErD2Tm3sycyswpNZFijOmXecUfERdFxMbh4/OAPwEeA34E/MXwadcB31+uThpjlp5xEnu2A7dHxAoG/yzuysx/j4hfAHdGxN8D/w3cOs4BleVRUVkhyg5TiRSqD8pyrGxK9Y5
GJbIolO2lbMxqTFQbdSxlAyqqBBiVlNTVnlUJXl2sVnUPqGutkn662MFd7gFlic5lXvFn5kPAO0dsf5rB539jzJsQf8PPmEax+I1pFIvfmEax+I1pFIvfmEaJLtZb54NFHAX+d/jnFuC53g5e4368EffjjbzZ+vEHmVmnHs6iV/G/4cAR+zJzaiIHdz/cD/fDb/uNaRWL35hGmaT4907w2LNxP96I+/FGfm/7MbHP/MaYyeK3/cY0ykTEHxHXRMT/RMRTEXHTJPow7MeBiHg4Ih6MiH09Hve2iDgSEY/M2rYpIu6JiCeHvy+cUD9uiYhnh2PyYER8sId+7IyIH0XEYxHxaET81XB7r2Mi+tHrmETEmoj4SUT8fNiPvxtuvywi7h+Ox7ciolvK5Rkys9cfYAWDMmCXA6uAnwNv77sfw74cALZM4LjvAd4FPDJr2z8ANw0f3wR8fkL9uAX4657HYzvwruHj9cATwNv7HhPRj17HBAjg/OHjlcD9DAro3AV8bLj9n4G/XMxxJvHKfxXwVGY+nYNS33cC106gHxMjM+8Dnp+z+VoGhVChp4KoRT96JzOnM/OB4eOTDIrF7KDnMRH96JUcsOxFcych/h3AM7P+nmTxzwR+GBE/i4g9E+rDGbZl5jQMbkJg6wT7cmNEPDT8WLDsHz9mExG7GNSPuJ8JjsmcfkDPY9JH0dxJiH9UqZFJWQ5XZ+a7gD8DboiI90yoH2cTXwHeymCNhmngC30dOCLOB74DfDozT/R13DH60fuY5CKK5o7LJMR/ENg56++y+Odyk5mHhr+PAN9jspWJDkfEdoDh7yOT6ERmHh7eeKeBr9LTmETESgaC+0Zmfne4ufcxGdWPSY3J8NgLLpo7LpMQ/0+B3cOZy1XAx4C7++5ERKyLiPVnHgMfAB7RrZaVuxkUQoUJFkQ9I7YhH6aHMYlB4blbgccy84uzQr2OSdWPvsekt6K5fc1gzpnN/CCDmdRfAn8zoT5czsBp+DnwaJ/9AL7J4O3j6wzeCV0PbAbuBZ4c/t40oX78K/Aw8BAD8W3voR9/xOAt7EPAg8OfD/Y9JqIfvY4J8IcMiuI+xOAfzd/Oumd/AjwF/BuwejHH8Tf8jGkUf8PPmEax+I1pFIvfmEax+I1pFIvfmEax+I1pFIvfmEax+I1plP8DMQi2q65RHfEAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "output = conv(img.unsqueeze(0))\n", "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)\n", "\n", "with torch.no_grad():\n", " conv.weight[:] = torch.tensor([[-1.0, 0.0, 1.0],\n", " [-1.0, 0.0, 1.0],\n", " [-1.0, 0.0, 1.0]])\n", " conv.bias.zero_()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGLdJREFUeJztnV2snNV1hp+FweAfgjnYxsfYgHFIBImKYx2hSFQoTdqIRpFIpCZKLiIuUBxVQSpSeoGo1IDUi6RqEuWiSuUUFFKlIWl+FFShNgilQrkhMZS/YFoTY4zrf8DY4d/26sWMpcNh1nvGc3y+Mez3kY7OnL1mf9+ePd86M7PfedeOzMQY0x5njHsAxpjx4OQ3plGc/MY0ipPfmEZx8hvTKE5+YxrFyW9Mozj5jWkUJ78xjXLmXDpHxHXAt4EFwD9n5tfU/ZcsWZITExMDY2+88UbZb/HixQPbzzij/t91/PjxMhYRZUwds+Lo0aNlTD0uNY4zzxztqam+sanOpWKjfgO0mkd1rmPHjo0UU891NQ41v+pc8/GN2GpO1LnOOuusge0HDx7kyJEj9SRPY+Tkj4gFwD8CfwbsAn4bEfdk5pNVn4mJCW6++eaBseeee64818aNGwe2V/8UAF555ZUytnDhwjJ2zjnnlLHqQnr++efLPs8++2wZUxfg8uXLy5ii+kekHnN1IQG8+eabZUxdnIsWLRrYrv65vvzyy2Xs8OHDI8WWLl06sL16EQL4wx/+UMbUfIz6j+3ss88e2K5eVFasWDGw/fbbby/7zGQub/uvBp7OzO2Z+QZwN3D9HI5njOmQuST/RcD0l+td/TZjzDuAuST/oPc4b3sfGBGbImJLRGxRb+uMMd0yl+TfBayd9vcaYPfMO2Xm5sycysypJUuWzOF0xphTyVyS/7fA5RGxLiIWAp8D7jk1wzLGzDcjr/Zn5tGIuAn4T3pS352Z+TvVJyLKlWW1slmtiquV+QMHDpQx9fFj7dq1ZWzZsmUD219//fWyj2JUSUnNVbXirOZKrfarmHrc1WMbVZ5V6o1SilavXj2w/bLLLiv7vPbaa2VMKQsLFiwoY+pxV/Oo5rdSMU5Gqp6Tzp+Z9wL3zuUYxpjx4G/4GdMoTn5jGsXJb0yjOPmNaRQnvzGNMqfV/pPljDPOKM04SmI7cuTIwHZlslBuupdeeqmMKalk5cqVA9svvvjiss8LL7xQxpTp58UXXyxj73nPe8rYueeeO7C9koZAG4yU7KXGWEmO55133kjjUOzZs6eMVfLsqlWryj5Kztu/f38ZU1KfQl2rFaM4Xd9235M+qzHmXYGT35hGcfIb0yhOfmMaxclvTKN0utqv+MAHPlDG9u7dO7B
dmXdU2SplIKnOBXV5J2WaUaiVb3XMquwT1Kv6yk6t5uPQoUMjxSrVQRmFlLKgVtJHUX3UONS5Rq0lqKgedzWHUD/PJ6M4+JXfmEZx8hvTKE5+YxrFyW9Mozj5jWkUJ78xjdKp1Hfs2LFSHnrve99b9lM7oVRUJhzQBpKnn366jO3atWtg+yWXXFL2UTKUkpvOP//8MjbKlmKqJqCSr1S9Q1VjrtpxSElRyiikzrVmzZoyVu1so8xdKqaeM/W8KPNOdcz5rnbtV35jGsXJb0yjOPmNaRQnvzGN4uQ3plGc/MY0ypykvojYARwBjgFHM3NK3f/o0aOl1KeknEoKUe42JQ9OTk6WMUU1RuUEVPKVGn9Vow301lWV81DJm0p+U/2UHFm5KpUT86mnnipjaozr168vY1WtvmqeQEt2ymmnXH1qi7XqOlDO1EqeVZLuTE6Fzv8nmXnwFBzHGNMhfttvTKPMNfkT+GVEPBQRm07FgIwx3TDXt/3XZObuiFgJ3BcRT2XmA9Pv0P+nsAl0vXljTLfM6ZU/M3f3f+8Hfg5cPeA+mzNzKjOn1CKWMaZbRk7+iFgSEeeeuA18HHjiVA3MGDO/zOVt/4XAz/uS2pnAv2bmf6gOx48fL2Uq5X6r5CYl51VbfAFMTEyUMeUGrOQh5fRSH3UOHqxFEhVTElAl9SgZSo1fOcuUJFY9N2q7q23btpWxyp0HcO2115axav6Vy27RokVlTElpyg2o5r+aYyWzVvnSidSXmduBq0btb4wZL5b6jGkUJ78xjeLkN6ZRnPzGNIqT35hG6XyvvkryUFJI5ehSfZR7TO2Dd+GFF5ax6nxKXlFSn5KGlOtMSZWVTKVko/nYf67qpx6Xil1xxRVlTM1x9djUNaBk4sOHD5cxVexUybPVczbqfpPD4ld+YxrFyW9Mozj5jWkUJ78xjeLkN6ZROl3tz8yylplavaxWPZcuXVr2eeGFF8qYqqunVoFfe+21ge1qJV2NUW3zpWq+KUPTKPXgRlUd1HNWjaOaQ9BzPzVVl4es6vQBbN26dWC7qglYbcsG2pik5lgZpJSRqGIU5WkmfuU3plGc/MY0ipPfmEZx8hvTKE5+YxrFyW9Mo3Qu9VW1x5SkVMlGauskJa2o+m0qVkl6SipTMuCaNWvKmKqrp0xLldSjtkN79dVXy5jaGkzNVWU+UhKmMlVt3LixjCmJsDLiqK3SlLFHzYd6ztT5qlqIqq5l9Xxa6jPGzIqT35hGcfIb0yhOfmMaxclvTKM4+Y1plFmlvoi4E/gksD8zP9hvmwB+BFwK7AA+m5m1Va5PZpauLrX1UyUPqT7nn39+GVPSnJKvqlpxShpS0ouSqFRMORYr2U7JRkrqUzXrlOxVjV9JXuvWrStjF1xwQRlT46+cdqM6MZVUqaRn5SKs5lE5IKt6gep5nskwr/zfA66b0XYLcH9mXg7c3//bGPMOYtbkz8wHgJkvNdcDd/Vv3wV86hSPyxgzz4z6mf/CzNwD0P9db21rjDktmfcFv4jYFBFbImKL+gxjjOmWUZN/X0RMAvR/l7WNMnNzZk5l5pRaxDLGdMuoyX8PcEP/9g3AL07NcIwxXTGM1PdD4CPA8ojYBXwV+Brw44i4EdgJfGaYkylXnypKWW25pLaZUu8ylINQyYBVocXly5eXfUZxc4GW2EbZqkk9rt27d5cxJSsqF1v1fK5YsaLso7YhUxLsoUOHylgl26lim8pduGzZsjKm5Dz1kbe6HpUc+cwzzwxsV+7Ntx1/tjtk5ueL0MeGPosx5rTD3/AzplGc/MY0ipPfmEZx8hvTKE5+Yxql0wKeEVFKLEoSq+RBJYcplESoHGKVq0+50ZQMqMavJDYl9VWSmJL6Dh48WMbUfKxcWX+ru5L6lHy1evXqMqZkNPXYKtQ1oMZ43nnnlTHlBnzuuefKWCVVKmdqtWegch3OxK/8xjSKk9+YRnHyG9MoTn5jGsXJb0yjOPmNaZTOpb5KRhnFqaa
KFSopRzm61DGr2OLFi8s+ygWm3GhKYlNzNco8KgfkqIUuq6KUO3fuLPtceeWVZUzJio8++mgZq/Y1VI9ZyWWqkKgq4KnOV12rqvirkj6Hxa/8xjSKk9+YRnHyG9MoTn5jGsXJb0yjdLrar1A1zqqVTWUGUltJqVVUtbpdGXvUOFS9QFVvTdX3U4pEdUz1uFTNOjUO9bj37t07sH3btm1ln6uvvrqMVXM/2ziqlXulpiilSF2nyiCljlmZhVSNxEo9UM/X2+479D2NMe8qnPzGNIqT35hGcfIb0yhOfmMaxclvTKMMs13XncAngf2Z+cF+223AF4ETrolbM/PeIY5VbnmlzBSV5KEMNcrg8vLLL5cxNY5KLlMyzp49e8qYMvaMKvVVEpaSw5RBR0mEaou1ShJTEqyaDyWZqsdWXTvKGKOuKyUTq+da9Vu1atXAdnUNV/N7MsauYV75vwdcN6D9W5m5of8za+IbY04vZk3+zHwAqEvJGmPekczlM/9NEfFYRNwZEXWNYWPMacmoyf8dYD2wAdgDfKO6Y0RsiogtEbFFfaXSGNMtIyV/Zu7LzGOZeRz4LlB+KTszN2fmVGZOVYt9xpjuGSn5I2Jy2p+fBp44NcMxxnTFMFLfD4GPAMsjYhfwVeAjEbEBSGAH8KVhTrZw4UIuvvjigTElbVUxVYtPOeZefPHFkWIrVqwY2P7888+XfdS2W9XWWqClKOUsq2RAVUNOyUNKVlyyZEkZq+aqagfttFMyoHpHWY1fXR9q2zB1LnVMJUdW51MSciVXn4zUN2vyZ+bnBzTfMfQZjDGnJf6GnzGN4uQ3plGc/MY0ipPfmEZx8hvTKJ0W8Fy8eDEbNmwYGNu9e3fZr5JQlKyhHGfVFk4A27dvL2OV+03JP0rOU04vJQ0p11kll6lxKIlK9VPzX83VmjVryj5qHpXspWTRI0eODGxXLkG1JZcqdjoxMVHG1FxVUraSZ6utwU5mGy+/8hvTKE5+YxrFyW9Mozj5jWkUJ78xjeLkN6ZROpX6zjnnHC6//PKBsZ07d5b9KilKOQGV5KGce08++WQZq2Seq666quyjUHJTJeUAnH9+XTipkpT2799f9lEuwVEKmkK9t97y5cvLPtWedaClMiUDVrKocgmquVKuRLWfoCpcWjk/lXRYuWOV03UmfuU3plGc/MY0ipPfmEZx8hvTKE5+Yxql09X+BQsWlCu6o2yhVW3FBNokos6lDEZVPbhly5aVfZSJSNX+UwYYpQRUxh61Iq7q9CmDlFqBr54ztYI9OTlZxpTBSNX+q86n+qjHrGoyjjpXlfqkru9K8bGxxxgzK05+YxrFyW9Mozj5jWkUJ78xjeLkN6ZRhtmuay3wfWAVcBzYnJnfjogJ4EfApfS27PpsZtaOGXo166q6dUrKqWQqVQNPmU5GNYlUdelUvb1RpT4lvylDUyUPqflQEpXqp6SoKqaMPevXry9jSmJTc1wZe5RBR9VIVEYnZdRS/So5Ul2nVb6onJjJMK/8R4GvZOYVwIeBL0fElcAtwP2ZeTlwf/9vY8w7hFmTPzP3ZObD/dtHgK3ARcD1wF39u90FfGq+BmmMOfWc1Gf+iLgU+BDwIHBhZu6B3j8IYOWpHpwxZv4YOvkjYinwU+DmzDx8Ev02RcSWiNiiPuMaY7plqOSPiLPoJf4PMvNn/eZ9ETHZj08CA8ufZObmzJzKzCm1GYIxpltmTf7oLQXfAWzNzG9OC90D3NC/fQPwi1M/PGPMfDGMq+8a4AvA4xHxSL/tVuBrwI8j4kZgJ/CZYU6oJI+TRdVFU+dRUs773ve+MlZJfWqLL/VRR0llynVWbUGlUFs/qeOpunqqlmD13Kg+q1atKmPPPPNMGXvppZfKWPW41TWwbt26MrZ3794ytmvXrjKm6iRW8qeSvw8fHvzJW0nVM5k1+TPz10AlBH9s6DMZY04r/A0/YxrFyW9Mozj5jWkUJ78xjeLkN6ZROi3gmZmlFKFkr8r
dpOQw5W5S2ypdcsklZawq1KkKgqqCipV0CLBv374ypqS5Si6rtjwDLUMpp51ynVWS2JIlS8o+qjCpGqOa/1GuN3XtqPErObWS5qB2cKrH1ZWrzxjzLsTJb0yjOPmNaRQnvzGN4uQ3plGc/MY0ymkj9SlJrJJllFyj3E3K8af2+KukqIULF5Z9lKy4cmVd/Eg5xFQxy4suumhge7UfHOhioUrqU067SppTBUGfffbZMqYes5LfqutqPvbjU8VJ1fVdFWTdv39giQygvoYt9RljZsXJb0yjOPmNaRQnvzGN4uQ3plE6X+1XK8sV1QqrOpaq76e2u1KrspU5Zu3atSONozIKga79p2KVGUSZgZShRtXwU+Oo5lGZgR566KEyprbCUrX/qvlQprBDhw6VMWXeUaqPqg1Zbff2yCOPDGyHevxqfmfiV35jGsXJb0yjOPmNaRQnvzGN4uQ3plGc/MY0yqxSX0SsBb4PrAKOA5sz89sRcRvwReBA/663Zua96ljHjx8vpRdltqnqlakaZ2oLJ1U7T8k1VR02VV9OGVmUSeT9739/GduxY0cZqww8an4nJyfLmDKKKONJ1U8ZjJRRSBm1RjH2qOtD1TtUMrGa48pwBbWcqqS+U1HDbxid/yjwlcx8OCLOBR6KiPv6sW9l5j8MfTZjzGnDMHv17QH29G8fiYitQP1vzBjzjuCkPvNHxKXAh4AH+003RcRjEXFnRNTbrxpjTjuGTv6IWAr8FLg5Mw8D3wHWAxvovTP4RtFvU0RsiYgt6vOeMaZbhkr+iDiLXuL/IDN/BpCZ+zLzWGYeB74LXD2ob2ZuzsypzJxSe7MbY7pl1uSP3pL0HcDWzPzmtPbpS8SfBp449cMzxswXw6z2XwN8AXg8Ik5oD7cCn4+IDUACO4AvzXagY8eOlY4pJZdV0ouS+pTsomrnLV68uIxVW2iprZiUC0yhnILKTbd9+/aB7UqiGsUVB9oBWVGND7TTbsWKFWVMyanVdTDqtaPqRlbuPNDyclXnUR2vesxKPn7bMWa7Q2b+Ghh0RKnpG2NOb/wNP2MaxclvTKM4+Y1pFCe/MY3i5DemUTot4KmkPlV4sHIqKalJbZOl5DwlyVTyipKGVOFM5fRSxT2VfFg57dRcqeMpCVbNVdWvkktBPy/q+VRjrJ4bJYnt3r27jKl+V1xxRRlTY6y2B1PXwOrVqwe2KylyJn7lN6ZRnPzGNIqT35hGcfIb0yhOfmMaxclvTKN0LvVVDjjlzKqKalZyB2j5TRWDVAUQqz3tDh48WPZRDjwl9SnpU7m9JiYmBraropSqmGUlQ6lzQS05jVJ8FLTspeSt6rpSsqIahyoWquRItddgJX+q+hfVdWqpzxgzK05+YxrFyW9Mozj5jWkUJ78xjeLkN6ZROpX6MrOU2ZSzrJIvlKyhimpW+5zN1q/aU00Vx3zllVdGio0q9V1wwQUD29UY1bkOHDhQxpScWjn+VGFSJZmqmHLMVRKhciQqGVDJxErWPZnCmieonkuFksxn4ld+YxrFyW9Mozj5jWkUJ78xjeLkN6ZRZl0ajIhzgAeAs/v3/0lmfjUi1gF3AxPAw8AXMrNeRqe34lytVKvV0GprpVFr8akVfWXAqKiMR6DNL0p1qGodgl5xrhQQtd2Vmiu1rZUyQVXPpzLoKBVj7969ZUwpAdX4K+UGYM2aNWVMMarCVK3QK2PPokWLBrYrVedt9x3iPq8DH83Mq+htx31dRHwY+Drwrcy8HHgRuHHosxpjxs6syZ89TrwcntX/SeCjwE/67XcBn5qXERpj5oWh3iNExIL+Dr37gfuA3wOHMvPEtzx2AbU53Rhz2jFU8mfmsczcAKwBrgYGFSgf+AEwIjZFxJaI2KI+ExljuuWkVvsz8xDwX8CHgWURcWKlYg0wcKeDzNycmVOZOaUqnRhjumXW5I+IFRGxrH97EfCnwFbgV8Bf9O92A/CL+RqkMebUM4wLYBK4KyIW0Ptn8ePM/Pe
IeBK4OyL+Dvhv4I5hTqjkoYpK2lKSlzI4KEOQOma15ZWSypSEqWQZVTtPbQFWmW1GqXMH2jSjnsvKLFRJVKDlN/WY1Rir2oXKlLRy5cqTPh7Aq6++WsZGqRs5iuysTFozmTX5M/Mx4EMD2rfT+/xvjHkH4m/4GdMoTn5jGsXJb0yjOPmNaRQnvzGNEqNIbyOfLOIA8Gz/z+VAbcfqDo/jrXgcb+WdNo5LMrO2cE6j0+R/y4kjtmTm1FhO7nF4HB6H3/Yb0ypOfmMaZZzJv3mM556Ox/FWPI638q4dx9g+8xtjxovf9hvTKGNJ/oi4LiL+JyKejohbxjGG/jh2RMTjEfFIRGzp8Lx3RsT+iHhiWttERNwXEdv6v+vqjfM7jtsi4v/6c/JIRHyig3GsjYhfRcTWiPhdRPxVv73TORHj6HROIuKciPhNRDzaH8ft/fZ1EfFgfz5+FBF15dhhyMxOf4AF9MqAXQYsBB4Frux6HP2x7ACWj+G81wIbgSemtf09cEv/9i3A18c0jtuAv+54PiaBjf3b5wL/C1zZ9ZyIcXQ6J0AAS/u3zwIepFdA58fA5/rt/wT85VzOM45X/quBpzNze/ZKfd8NXD+GcYyNzHwAmGnYv55eIVToqCBqMY7Oycw9mflw//YResViLqLjORHj6JTsMe9Fc8eR/BcBz037e5zFPxP4ZUQ8FBGbxjSGE1yYmXugdxECdUWJ+eemiHis/7Fg3j9+TCciLqVXP+JBxjgnM8YBHc9JF0Vzx5H8g0rbjEtyuCYzNwJ/Dnw5Iq4d0zhOJ74DrKe3R8Me4BtdnTgilgI/BW7OzLFVex0wjs7nJOdQNHdYxpH8u4Dpm7SXxT/nm8zc3f+9H/g5461MtC8iJgH6v/ePYxCZua9/4R0HvktHcxIRZ9FLuB9k5s/6zZ3PyaBxjGtO+uc+6aK5wzKO5P8tcHl/5XIh8Dngnq4HERFLIuLcE7eBjwNP6F7zyj30CqHCGAuinki2Pp+mgzmJXqHDO4CtmfnNaaFO56QaR9dz0lnR3K5WMGesZn6C3krq74G/GdMYLqOnNDwK/K7LcQA/pPf28U1674RuBC4A7ge29X9PjGkc/wI8DjxGL/kmOxjHH9N7C/sY8Ej/5xNdz4kYR6dzAvwRvaK4j9H7R/O3067Z3wBPA/8GnD2X8/gbfsY0ir/hZ0yjOPmNaRQnvzGN4uQ3plGc/MY0ipPfmEZx8hvTKE5+Yxrl/wHo7ZuXD1PlgwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
 "# Apply the current conv layer to a single image and display output channel 0.\n",
 "batch = img.unsqueeze(0)  # add the leading batch dimension\n",
 "output = conv(batch)\n",
 "plt.imshow(output[0, 0].detach(), cmap='gray')\n",
 "plt.show()"
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 3, 32, 32]), torch.Size([1, 3, 16, 16]))" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# 2x2 max pooling: compare the batch shape before and after.\n",
 "pool = nn.MaxPool2d(2)\n",
 "batch = img.unsqueeze(0)\n",
 "output = pool(batch)\n",
 "\n",
 "batch.shape, output.shape"
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [
 "# Convolutional front end only: two conv -> tanh -> max-pool stages.\n",
 "model = nn.Sequential(\n",
 "    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),\n",
 "    nn.Tanh(),\n",
 "    nn.MaxPool2d(kernel_size=2),\n",
 "    nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding=1),\n",
 "    nn.Tanh(),\n",
 "    nn.MaxPool2d(kernel_size=2),\n",
 "    # ...\n",
 ")"
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [
 "model = nn.Sequential(\n",
 "    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),\n",
 "    nn.Tanh(),\n",
 "    nn.MaxPool2d(kernel_size=2),\n",
 "    nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding=1),\n",
 "    nn.Tanh(),\n",
 "    nn.MaxPool2d(kernel_size=2),\n",
 "    # ... 
<1>\n", " nn.Linear(8 * 8 * 8, 32),\n", " nn.Tanh(),\n", " nn.Linear(32, 2))" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(18090, [432, 16, 1152, 8, 16384, 32, 64, 2])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "numel_list = [p.numel() for p in model.parameters()]\n", "sum(numel_list), numel_list" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "tags": [ "raises-exception" ] }, "outputs": [ { "ename": "RuntimeError", "evalue": "size mismatch, m1: [64 x 8], m2: [512 x 32] at c:\\a\\w\\1\\s\\tmp_conda_3.6_091443\\conda\\conda-bld\\pytorch_1544087948354\\work\\aten\\src\\th\\generic/THTensorMath.cpp:940", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 487\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 488\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 489\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 490\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 491\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 90\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 91\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 92\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 93\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 94\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 
"\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 487\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 488\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 489\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 490\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 491\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 65\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mweak_script_method\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 66\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 67\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 68\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 69\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\functional.py\u001b[0m in \u001b[0;36mlinear\u001b[1;34m(input, weight, bias)\u001b[0m\n\u001b[0;32m 1352\u001b[0m \u001b[0mret\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0maddmm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_unwrap_optional\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1353\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1354\u001b[1;33m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1355\u001b[0m 
\u001b[1;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1356\u001b[0m \u001b[0moutput\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_unwrap_optional\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mRuntimeError\u001b[0m: size mismatch, m1: [64 x 8], m2: [512 x 32] at c:\\a\\w\\1\\s\\tmp_conda_3.6_091443\\conda\\conda-bld\\pytorch_1544087948354\\work\\aten\\src\\th\\generic/THTensorMath.cpp:940" ] } ], "source": [ "model(img.unsqueeze(0))" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [
 "class Net(nn.Module):\n",
 "    \"\"\"CNN assembled from named submodules.\n",
 "\n",
 "    Two conv -> tanh -> 2x2 max-pool stages feed two linear layers that\n",
 "    produce 2 output scores per input image.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self):\n",
 "        super().__init__()\n",
 "        # Stage 1: 3 input channels -> 16 feature maps.\n",
 "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)\n",
 "        self.act1 = nn.Tanh()\n",
 "        self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
 "        # Stage 2: 16 feature maps -> 8 feature maps.\n",
 "        self.conv2 = nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding=1)\n",
 "        self.act2 = nn.Tanh()\n",
 "        self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
 "        # Classifier head on the flattened 8 * 8 * 8 features.\n",
 "        self.fc1 = nn.Linear(in_features=8 * 8 * 8, out_features=32)\n",
 "        self.act3 = nn.Tanh()\n",
 "        self.fc2 = nn.Linear(in_features=32, out_features=2)\n",
 "\n",
 "    def forward(self, x):\n",
 "        \"\"\"Return the 2 unnormalised class scores for each image in `x`.\"\"\"\n",
 "        stage = self.conv1(x)\n",
 "        stage = self.act1(stage)\n",
 "        stage = self.pool1(stage)\n",
 "\n",
 "        stage = self.conv2(stage)\n",
 "        stage = self.act2(stage)\n",
 "        stage = self.pool2(stage)\n",
 "\n",
 "        # Collapse channels and spatial dims into one axis for the linear layers.\n",
 "        flat = stage.view(-1, 8 * 8 * 8)  # <1>\n",
 "        hidden = self.act3(self.fc1(flat))\n",
 "        return self.fc2(hidden)"
] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(18090, [432, 16, 1152, 8, 16384, 32, 64, 2])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "model = Net()\n",
 "\n",
 "# Element count of each parameter tensor, followed by the grand total.\n",
 "numel_list = [param.numel() for param in model.parameters()]\n",
 "sum(numel_list), numel_list"
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "import 
torch.nn.functional as F\n",
 "\n",
 "class Net(nn.Module):\n",
 "    \"\"\"Same network written with the functional API.\n",
 "\n",
 "    Only the layers that own parameters are stored as submodules; tanh and\n",
 "    max pooling are applied as plain function calls in `forward`.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self):\n",
 "        super().__init__()\n",
 "        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)\n",
 "        self.conv2 = nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding=1)\n",
 "        self.fc1 = nn.Linear(in_features=8 * 8 * 8, out_features=32)\n",
 "        self.fc2 = nn.Linear(in_features=32, out_features=2)\n",
 "\n",
 "    def forward(self, x):\n",
 "        \"\"\"Return the 2 unnormalised class scores for each image in `x`.\"\"\"\n",
 "        stage = torch.tanh(self.conv1(x))\n",
 "        stage = F.max_pool2d(stage, 2)\n",
 "        stage = torch.tanh(self.conv2(stage))\n",
 "        stage = F.max_pool2d(stage, 2)\n",
 "        # Collapse channels and spatial dims into one axis for the linear layers.\n",
 "        flat = stage.view(-1, 8 * 8 * 8)\n",
 "        hidden = torch.tanh(self.fc1(flat))\n",
 "        return self.fc2(hidden)"
] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.0157, 0.1143]], grad_fn=)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Smoke test: push one image (with a batch dimension added) through a fresh network.\n",
 "model = Net()\n",
 "batch = img.unsqueeze(0)\n",
 "model(batch)"
] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [
 "import datetime\n",
 "\n",
 "def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):\n",
 "    \"\"\"Train `model` for `n_epochs` passes over `train_loader`.\n",
 "\n",
 "    For every batch the loss is computed, gradients are reset and\n",
 "    back-propagated, and the optimizer takes one step. The per-batch losses\n",
 "    are summed over the epoch; that sum is printed with a timestamp on the\n",
 "    first epoch and on every tenth epoch.\n",
 "    \"\"\"\n",
 "    for epoch in range(1, n_epochs + 1):\n",
 "        epoch_loss = 0.0\n",
 "        for batch_imgs, batch_labels in train_loader:\n",
 "            batch_loss = loss_fn(model(batch_imgs), batch_labels)\n",
 "\n",
 "            optimizer.zero_grad()\n",
 "            batch_loss.backward()\n",
 "            optimizer.step()\n",
 "\n",
 "            epoch_loss += batch_loss.item()\n",
 "\n",
 "        should_report = epoch == 1 or epoch % 10 == 0\n",
 "        if should_report:\n",
 "            print('{} Epoch {}, Training loss {}'.format(\n",
 "                datetime.datetime.now(), epoch, float(epoch_loss)))"
] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
 "2018-12-30 15:53:41.593150 Epoch 1, Training loss 88.4665624499321\n",
 "2018-12-30 15:55:43.742503 Epoch 10, Training loss 51.45866137742996\n",
 "2018-12-30 15:57:59.663029 Epoch 20, Training loss 47.65159119665623\n",
 "2018-12-30 16:00:40.054150 Epoch 30, Training loss 44.35375756025314\n",
 "2018-12-30 16:02:54.528544 Epoch 40, Training loss 
40.991573348641396\n",
 "2018-12-30 16:05:07.754276 Epoch 50, Training loss 37.843220107257366\n",
 "2018-12-30 16:07:21.260259 Epoch 60, Training loss 34.528036408126354\n",
 "2018-12-30 16:09:34.664515 Epoch 70, Training loss 31.973807267844677\n",
 "2018-12-30 16:11:47.704744 Epoch 80, Training loss 29.7293216958642\n",
 "2018-12-30 16:14:00.862659 Epoch 90, Training loss 27.13489008694887\n",
 "2018-12-30 16:16:13.841052 Epoch 100, Training loss 25.345759883522987\n"
] } ], "source": [
 "# Shuffled mini-batches of 64 for training.\n",
 "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n",
 "\n",
 "model = Net()\n",
 "optimizer = optim.SGD(model.parameters(), lr=1e-2)\n",
 "loss_fn = nn.CrossEntropyLoss()\n",
 "\n",
 "training_loop(\n",
 "    n_epochs=100,\n",
 "    optimizer=optimizer,\n",
 "    model=model,\n",
 "    loss_fn=loss_fn,\n",
 "    train_loader=train_loader,\n",
 ")"
] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.930400\n", "Accuracy: 0.930400\n" ] } ], "source": [
 "# Unshuffled loaders for evaluation.\n",
 "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n",
 "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
 "\n",
 "# Print one accuracy per loader: training set first, then validation set.\n",
 "for loader in [train_loader, val_loader]:\n",
 "    correct = 0\n",
 "    total = 0\n",
 "\n",
 "    with torch.no_grad():  # evaluation only, no gradients needed\n",
 "        # Iterate `loader` (not `train_loader`) so each pass scores its own dataset.\n",
 "        for imgs, labels in loader:\n",
 "            outputs = model(imgs)\n",
 "            _, predicted = torch.max(outputs, dim=1)  # <1>\n",
 "            total += labels.shape[0]\n",
 "            correct += int((predicted == labels).sum())\n",
 "\n",
 "    print(\"Accuracy: %f\" % (correct / total))"
] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [
 "# Save only the learned parameters (the state_dict).\n",
 "torch.save(model.state_dict(), data_path + 'birds_vs_airplanes.pt')"
] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [
 "# A state_dict holds weights only, so build a fresh Net and load the weights into it.\n",
 "loaded_model = Net()  # <1>\n",
 "loaded_model.load_state_dict(torch.load(data_path + 'birds_vs_airplanes.pt'))"
] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.891500\n" ] } ], "source": [
 "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
 "\n",
 "correct = 0\n",
 "total = 0\n",
 "\n",
 "with torch.no_grad():\n",
 "    for imgs, labels in val_loader:\n",
 "        # Score the reloaded network so this cell actually exercises the saved checkpoint.\n",
 "        outputs = loaded_model(imgs)\n",
 "        _, predicted = torch.max(outputs, dim=1)\n",
 "        total += labels.shape[0]\n",
 "        correct += int((predicted == labels).sum())\n",
 "\n",
 "print(\"Accuracy: %f\" % (correct / total))"
] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "18090" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "import torch\n",
 "import torch.nn as nn\n",
 "import torch.nn.functional as F\n",
 "\n",
 "class Net(nn.Module):\n",
 "    \"\"\"Variant of the network that applies ReLU after each convolution.\n",
 "\n",
 "    The hidden linear layer still uses tanh; layer sizes are unchanged.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self):\n",
 "        super(Net, self).__init__()\n",
 "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
 "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
 "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
 "        self.fc2 = nn.Linear(32, 2)\n",
 "\n",
 "    def forward(self, x):\n",
 "        \"\"\"Return the 2 unnormalised class scores for each image in `x`.\"\"\"\n",
 "        out = F.max_pool2d(torch.relu(self.conv1(x)), 2)\n",
 "        out = F.max_pool2d(torch.relu(self.conv2(out)), 2)\n",
 "        # Flatten the pooled feature maps before the linear layers.\n",
 "        out = out.view(-1, 8 * 8 * 8)\n",
 "        out = torch.tanh(self.fc1(out))\n",
 "        out = self.fc2(out)\n",
 "        return out\n",
 "\n",
 "model = Net()\n",
 "# Total number of parameter elements in the model.\n",
 "sum([p.numel() for p in model.parameters()])"
] }, { "cell_type": "code", "execution_count": 37, "metadata": { "tags": [ "raises-exception" ] }, "outputs": [ { "ename": "RuntimeError", "evalue": "size mismatch, m1: [64 x 8], m2: [512 x 32] at c:\\a\\w\\1\\s\\tmp_conda_3.6_091443\\conda\\conda-bld\\pytorch_1544087948354\\work\\aten\\src\\th\\generic/THTensorMath.cpp:940", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", 
"\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 10\u001b[0m nn.Linear(32, 2))\n\u001b[0;32m 11\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 487\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 488\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 489\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 490\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 491\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m 
\u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 90\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 91\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 92\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 93\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 94\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 487\u001b[0m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 488\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 489\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 490\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 491\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 65\u001b[0m \u001b[1;33m@\u001b[0m\u001b[0mweak_script_method\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 66\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 67\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 68\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 69\u001b[0m \u001b[1;32mdef\u001b[0m 
\u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\functional.py\u001b[0m in \u001b[0;36mlinear\u001b[1;34m(input, weight, bias)\u001b[0m\n\u001b[0;32m 1352\u001b[0m \u001b[0mret\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0maddmm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_unwrap_optional\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1353\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1354\u001b[1;33m \u001b[0moutput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mt\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1355\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1356\u001b[0m \u001b[0moutput\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_unwrap_optional\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mRuntimeError\u001b[0m: size mismatch, m1: [64 x 8], m2: [512 x 32] at 
c:\\a\\w\\1\\s\\tmp_conda_3.6_091443\\conda\\conda-bld\\pytorch_1544087948354\\work\\aten\\src\\th\\generic/THTensorMath.cpp:940" ] } ], "source": [
 "# Convolutional stages followed directly by linear layers. Nothing flattens\n",
 "# the pooled feature maps before the first nn.Linear; the output recorded for\n",
 "# this cell is the resulting size-mismatch RuntimeError.\n",
 "conv_stages = [\n",
 "    nn.Conv2d(3, 16, kernel_size=3, padding=1),\n",
 "    nn.Tanh(),\n",
 "    nn.MaxPool2d(2),\n",
 "    nn.Conv2d(16, 8, kernel_size=3, padding=1),\n",
 "    nn.Tanh(),\n",
 "    nn.MaxPool2d(2),\n",
 "]\n",
 "linear_stages = [\n",
 "    nn.Linear(8*8*8, 32),\n",
 "    nn.Tanh(),\n",
 "    nn.Linear(32, 2),\n",
 "]\n",
 "model = nn.Sequential(*conv_stages, *linear_stages)\n",
 "\n",
 "model(img.unsqueeze(0))"
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }