{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torchvision import datasets\n",
    "\n",
    "cifar10 = datasets.CIFAR10('data', train=True, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "cifar10_val = datasets.CIFAR10('data', train=False, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torchvision.datasets.cifar.CIFAR10"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(cifar10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "isinstance(cifar10, torch.utils.data.Dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50000"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(cifar10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "img, label = cifar10[99]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "print(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "
"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['CenterCrop',\n",
       " 'ColorJitter',\n",
       " 'Compose',\n",
       " 'FiveCrop',\n",
       " 'Grayscale',\n",
       " 'Lambda',\n",
       " 'LinearTransformation',\n",
       " 'Normalize',\n",
       " 'Pad',\n",
       " 'RandomAffine',\n",
       " 'RandomApply',\n",
       " 'RandomChoice',\n",
       " 'RandomCrop',\n",
       " 'RandomGrayscale',\n",
       " 'RandomHorizontalFlip',\n",
       " 'RandomOrder',\n",
       " 'RandomResizedCrop',\n",
       " 'RandomRotation',\n",
       " 'RandomSizedCrop',\n",
       " 'RandomVerticalFlip',\n",
       " 'Resize',\n",
       " 'Scale',\n",
       " 'TenCrop',\n",
       " 'ToPILImage',\n",
       " 'ToTensor',\n",
       " '__builtins__',\n",
       " '__cached__',\n",
       " '__doc__',\n",
       " '__file__',\n",
       " '__loader__',\n",
       " '__name__',\n",
       " '__package__',\n",
       " '__path__',\n",
       " '__spec__',\n",
       " 'functional',\n",
       " 'transforms']"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torchvision import transforms\n",
    "dir(transforms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 32, 32])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torchvision import transforms\n",
    "\n",
    "to_tensor = transforms.ToTensor()\n",
    "\n",
    "img, label = cifar10[99]\n",
    "\n",
    "img_t = to_tensor(img)\n",
    "\n",
    "img_t.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "cifar10 = datasets.CIFAR10('data', train=True, download=True,\n",
    " transform=transforms.ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img, _ = cifar10[99]\n",
    "type(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 32, 32])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.float32"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.), tensor(1.))"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img.min(), img.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAH0hJREFUeJztnXuMXdd13r913/PkcIavEUWJokiLlGS9SqtK5Rqy0jqKG0Q2mih2mkAIDDMoYqBGnT8EF6gdoH8kRS3XTQsHdKRECRy/bVioDceqokQx/NDLFEmJkkyRNJ8aPuY9c9939Y971VL0/vZcvu5Q2d8PGMzMXnefs+6+Z51zz/7OWtvcHUKI9MgstwNCiOVBwS9Eoij4hUgUBb8QiaLgFyJRFPxCJIqCX4hEUfALkSgKfiESJXcxnc3sPgCfA5AF8Ofu/sex1+fzOS+V8kFbq9Wk/bzVYg7QPpnoaY33i9ncw35E3EDsCUqz7AV4AVhkh9lceHyz2XA7AJQX5yN7I2MPoK/UR20D/YPB9sXFBdqnXi9TWybynvNZfhhncsVge/9guB0AmpFjsVzj/udz/KDL5yKfdSZ8jOSyfHuLi+E+U1NlLCzUYofP/99+Ny8KYe0j938B+NcAjgJ41swed/eXWZ9SKY87tm8O2uZnJ+m+GrVqsD2b54PT3x8J4lbkbWe4rVYN+5GPbK5Zr1FbPjdEbRYJ/3yBH7grx9YG20eG19E+e/b8gNrg3P9tN9xMbXfd+i+C7c+/+Azt88bxvdTWX+Qnr6uGVlPbwKrrgu233L2J9pmtTlPbvoPc/3Vr+ee5dozbiv3hk81I5AS1e1cj2P4//vSHtM+5XMzX/jsB7Hf3A+5eA/BlAPdfxPaEED3kYoJ/PYAjZ/1/tNMmhHgbcDH3/KHvpb9wI2JmOwDsAIBi5KubEKK3XMyV/yiADWf9fzWA4+e+yN13uvt2d9+ez/NJDyFEb7mY4H8WwBYzu87MCgA+BODxS+OWEOJyc8Ff+929YWYfA/C3aEt9j7r7S9FO5jAjM+aRLwWZQinYnitGzl0RscOc76yyEPYPAFpE9orNvlsuIvXlwjO2bQrUMjU7Q22np6aC7eXyLu5HRM4b6AuPPQBMTJ2htid+9HfB9pZxGW22VqG2vogfsxXeb2Q4LDn2FcOqEwBsGOcz89Mzv/Dl9v8xOsb9GBrmx9xiNSwfzi/yY6DUH76FzmS6UvkAXKTO7+7fBfDdi9mGEGJ50BN+QiSKgl+IRFHwC5EoCn4hEkXBL0SiXNRs//niDtSbYemrb2iA9quQ3JJWk0srzQZ/mrBa4XLe4GBYGgIAr8+G98WyDgG0jJ9fi7mILJPhmXb5Epe9anPhzLhiictGMC45uvHEnuMnD1NbnmQ7VRe51FeILCHRV+B+VDN8m7VD4WShxdox2qdUXEltV224mtoqczSnDRNz3MdsIXwczDnPIDw5GT6G642YfPxWdOUXIlEU/EIkioJfiERR8AuRKAp+IRKlp7P9GQOKJBlnZnaR9jMPz1THkk5iiSAL5fOv0wcA5Vp4Orp/MDKT3uSzr+VFXrOuXuF+5Ep1ajML98tFash57BpA1BkA6MtzRaVeDx9amSb3o+VcvVmMJFz19fFEnPJiONFp4hTf1/ziEWobHr2X2kr9vFTabGWC2irl8Bg3wRWO0zPh8Wg0+XFzLrryC5EoCn4hEkXBL0SiKPiFSBQFvxCJouAXIlF6KvU1Wy0skMSTOldeMLIiLNtVylwebEYSHGZmuIQyOxtO3gGAMbLqyiBXFTEzG5H65rmMli/wj2ZxIZKIQ6RKd36er5Z50kmrHqlBmOWyUjEf3qaV+PYa3I22Tkzoz3JbmaxcdWqKJ80Ui5F6gdO8buEUkd8A4ORpbhseDn82sRyd8kL4fXmz+xp+uvILkSgKfiESRcEvRKIo+IVIFAW/EImi4BciUS5K6jOzQwDmADQBNNx9e+z
1GTMUSuGsrlKJZ4jNk+Wp6hFtqFbjb61a5fXxRse4H8PD4faJ43x7tRbPwCuSsQCASMIccpGxqiyGpZ5KhftRKkbGKpJZ5i2uRbHkvXykpmGzzmWqTET6LJd4v+mFsP+NZqSm3ko+vicmjlJbrcWzNCsRLbtSDkuLzUiGXrka9j/W51wuhc7/Xnc/fQm2I4ToIfraL0SiXGzwO4Dvm9nzZrbjUjgkhOgNF/u1/253P25mawA8YWavuPvTZ7+gc1LYAQDFYmQdbiFET7moK7+7H+/8PgngWwDuDLxmp7tvd/ft+byCX4grhQsOfjMbMLOhN/8G8D4A4eVRhBBXHBfztX8tgG+Z2Zvb+Rt3/16sQ6sFLM6HpYhMlsskOeJlNs8LZ3pE8ti8bYTahgb4kMyeDstlzZWRrLJIxlwmUlSzRqQcABgZ5f1WrgrLVPOz3MdqmY/V6Fq+jFrRuCQ2Ox+W2OqILVvFt1eOyLqLLT4eDbKkW7PMJcw54/uq1ri8uXJ0lNoidVCx6GGpuJjjx3ezNRdsd4+lRr6VCw5+dz8A4NYL7S+EWF4k9QmRKAp+IRJFwS9Eoij4hUgUBb8QidLbtfoywHB/+HyTjWRtLcyFZZl8LlIAs8RlkhYp6ggAdePZb14IS2JjJNsPAI4f4ftisicANJ37kSvxsVo5HJbLmpH1CQuR7fXHxrHF/W+RbLqRVbw4ZpnX1MTcDM+KmzwdzvoEgMH+sP850g4AzRY/rupVbpuZCctvQDyTtETWlcyP8M/sqvWrw30KvMDouejKL0SiKPiFSBQFvxCJouAXIlEU/EIkSk9n+x1ArRWewZyb4LOhK0fD0+mtJl+uq26RGex+vnTSfGQ2t1kLz2CXCnzmeGiI21YM8ISUyWk+kz4zGVEJqmEfc+DvazDiY2WRj1WN7AsAhkeKwfYCy9ICUIyoJmcm+Mx33yAfx4Vq+BgpRhSOauwYWOQqTH+Tj2OuGEv+Co+xR5KgykQaqUcSj85FV34hEkXBL0SiKPiFSBQFvxCJouAXIlEU/EIkSk+lvlazhbn5sETRbHLZaIFIIbPTXIYq5rkkk83yWnHZTGTJKNJeq3FJJpfntr4Cl5TKdX5edo/JkWEZsBV5z5VJnjRTyPJDJJ/t4354WGKLjX2tzN9zxiJLcs3wY2flWFhyLFf5sVOt8fEdG4klJnGZbbHKbS1yiMxMcT/G164Mtnv3q3Xpyi9Eqij4hUgUBb8QiaLgFyJRFPxCJIqCX4hEWVLqM7NHAfwagJPufnOnbRTAVwBsBHAIwAPuPrXUtjKZDIZKYXloYo4v17VYng22u/NsLm9Glnea4+e867YNUluFlIqbnueykUfq3FUb3FZawd/bwGBELpsJb3P6DPexleWSUsu4ROXgtv6R8Bi3MlyWW7G6n9quK3LbzDSXKht14mNk/ayhFfz4GI7U1UOLh9Ph4zwDdXQ0vCTacCTbslYLx4ufh9bXzZX/LwHcd07bQwCedPctAJ7s/C+EeBuxZPC7+9MAJs9pvh/AY52/HwPwgUvslxDiMnOh9/xr3f0EAHR+r7l0LgkhesFlf7zXzHYA2AEAhQK/jxVC9JYLvfJPmNk4AHR+n2QvdPed7r7d3bfn8wp+Ia4ULjT4HwfwYOfvBwF8+9K4I4ToFd1IfV8CcA+AVWZ2FMCnAPwxgK+a2UcAHAbwm93sLJMx9JOliTKRbwUZsnxSiSdYYdVably1lr/tRpNLYrPzYfmwxlUcNOpcchy9imfFjYzybVarfJtzJAOyEZGAvMqvAes2c7mpXuF+ZC1sy+Z4H2S4dJgrcNvAIP88T50MS4sDxUi2YqTY5sw892NogI/VVQNcQp4iUvFwRO4tlcK2TCQr9VyWDH53/zAx/XLXexFCXHHoCT8hEkXBL0SiKPiFSBQFvxCJouAXIlF6WsCzWq3jtQNHw0bjmWqlvvA5avU4l8rGxri0lYmsW9eo8SEZGAzLKH1F7vvhn3NpyyLn3vk
5LilNn+G2Rp28t0h2XnGQZ8w1Imu/ZXORa0czLLVOT3EpNZ/jmmk+cqhaM5LdSaTWlvFjIKaWtSKFOBeKfDw2ruXHSGY2nJXYasQKtYbfs3v3Up+u/EIkioJfiERR8AuRKAp+IRJFwS9Eoij4hUiUnkp97oZWKyx51Gt8bb2x1eH11jZtDRc+BICpE1xSmpzktsHwEmgAgOGR8HBNneIS1dhVXOLpH+JSztQpLtnUI2sD3nndO4LtW1bzNMGv7X2W2pDjMtqBffx9rx4PZ7h5RGJrNPi1qBrJjmxGbLlSWPId3xQp1DrLZeLKCV5odqDObVOVSJFREoa1RR4ThVL4+PCIjH0uuvILkSgKfiESRcEvRKIo+IVIFAW/EInS09n+Qi6LDStXBG37j03QfgukxtlLe2jRYNQrfMa2r8Rneo8c5DPYI2Phme9Glc/KtiysVADAxDHer2+Az7JXFnlyyR3rtgTb33fXu2ifmSpfQmvvwSPUdu+2bdT24rHXg+3Wz5WWRpmP1VXrx6jt0Ov82FnbHz7e1hW4CjOfjXwuwzwJ6vSZaWrL9/EktEY9PCZDg7wm4KiFbTlTYo8QYgkU/EIkioJfiERR8AuRKAp+IRJFwS9EonSzXNejAH4NwEl3v7nT9mkAHwVwqvOyT7r7d5fcWTaL0ZXDQdvK8gztNzURTlbwFpfDhiI1/BYWFqgtR+oFAkBlPry/Mt8cKk1uXODKENasHaK2eoXLRvvLc8H2/h+/QPu87xou2W3Jr6K2bdduorYdf/5KsH3y1Dzt867bb6W2jRv5KvAVIgUDwMxkWLY7NcGTwqol/sHUiSwHAPU8zwpbs4777/MniIF2Qa40Emw3e4N3Oodurvx/CeC+QPtn3f22zs+SgS+EuLJYMvjd/WkAkz3wRQjRQy7mnv9jZrbbzB41s0gWvBDiSuRCg//zAK4HcBuAEwA+w15oZjvM7Dkze65W54+lCiF6ywUFv7tPuHvT3VsAvgDgzshrd7r7dnffXsj3NJVACBHhgoLfzMbP+veDAPZeGneEEL2iG6nvSwDuAbDKzI4C+BSAe8zsNrTFiEMAfr+bnTW9ifnGbNA2OByWAAFgfj4sXy3McNmlVOQZUStXcYnw5Cme4bZyNGyrV7kmc2qSb68VyTycPcPfW8bCS2EBwDv/5e8E2+ffOEb7zL8RzsADgNn5KWo7fYRv8xO/9YFg+9//dDftM7D+OmpbN7qa2spbuUx87PC+YPvkMSKvAagM8M/T8vzYqc/xz/q1I1yCmy2Hx3jtSDgjEQBGNl8TbM/mD9A+57Jk8Lv7hwPNj3S9ByHEFYme8BMiURT8QiSKgl+IRFHwC5EoCn4hEqWnT91Uaw28fjCcJlBv8iWX+gfCst2a9bwIY6XMnyacXeASW+w5pINHw/1WDfFz6E1rePbYAnjGXL3OZaNikReRvPX2fxZsb5Z5xlxrz3PU9uR3uER1/NjL1Pah3/7tYPvcJM/q+8aL4UxAAHjv791GbbEPrUZk2KuNL5+Vf/lFahsq8mMuZ9w2bdzHmVJY0msUuKRbnzodbPdm90/R6sovRKIo+IVIFAW/EImi4BciURT8QiSKgl+IRDH3SJXAS0whn/e1q8JFf/J5Lr8VSuH1x+rG5bDmAreNbeISSq7GC2f+ylw4o+uBU8dpn8fXbKS27w3xTEZr8qy+GldF8Uv3/HKw/d+9917ap3FgP7U9teuH1HbiJH/f777x5mD76RmeJdjKRrItS3ysqmf4Wn1DmzcG229o8OPt1/t5sc08+OB7ZD0+r0TWczwaXnOyfJxnHh5+/afB9t969QheWqx0tWCfrvxCJIqCX4hEUfALkSgKfiESRcEvRKL0NLEnm3MMj4RnS0eG+Sz7sVPhJIbKHJ/UnJnntu2jo9T2qetvpLab3rkh2J45yWewDx7gtU2/Hln6ySKJThnn7+2HfxtePOn2dXx87Y3D1Hbzjeuo7dcfCFV4azOH8Mz9OPh73vk
//5Ta1mzeSm0rSD07ABj38Az8Lf28xqNv5cuQ1bbxBKnMO26iNuzeRU2tJ74fbM+fPEL7bK2FE3hK56He6covRKIo+IVIFAW/EImi4BciURT8QiSKgl+IROlmua4NAP4KwDoALQA73f1zZjYK4CsANqK9ZNcD7s41LwA5GFZnwxJLeXKR9ivNh+WLoX5+7npwgEtbf1jhtdZWnAjLigBQORZOwMgdPET7/EqZS1vHVhSp7ZuRpJ9p4zJgJReW2J7/u3+kfVYZT6i5+xRPcsm9wZN+Bs+cCreXeYLL7+3jh8/YKz+ithUlnqQzOBOuGZh3PoZW5Ulhto5Ln7aFy8StQV53MTsfXm4sM83Hw/vGw4ZMeNyDL+3iNQ0An3D3bQDuAvAHZnYjgIcAPOnuWwA82flfCPE2Ycngd/cT7v5C5+85APsArAdwP4DHOi97DEB4ZUYhxBXJed3zm9lGALcD+AmAte5+AmifIADw74dCiCuOrh/vNbNBAN8A8HF3nzXrql4AzGwHgB0AUMxrflGIK4WuotHM8mgH/hfd/Zud5gkzG+/YxwEEZ8Pcfae7b3f37fmsgl+IK4Ulo9Hal/hHAOxz94fPMj0O4MHO3w8C+Pald08IcblYsoafmb0bwD8C2IO21AcAn0T7vv+rAK4BcBjAb7p7eC2uDmtGSv5v7wlnYA2ORurZkaWO1r7Oa7d99DCXf7KbNlNb7lou19iPfxxs98P7eB9wOQ8tvrTSqdHwEk4AcGZojNrmC+HbseuKg7TP6Aq+PevjMqAV+F2j94f3lx3mfmRXcz/Qz6Vb7+c1GVu5sLTcbHA5r5Xht7S5Ub7EWjbDxwp5nkXYIrvzp57i2/ve/wk2//NDr+L58mJX9+RL3vO7+w8AsI2Fq0UKIa54dBMuRKIo+IVIFAW/EImi4BciURT8QiRKTwt45vM5XE3knHyeyyTNVliOvHf/Au1TGOKSTGbFWmrDnheoyU4dC7ff/Eu8z2284CM2rKem9SPhZc0AYH2Ry0aohLMIW6e5LAqSgQcATVIoEgAyfVy2s1ZYSmvO8+xNP8CX//ICv065cR+9GrZ5tcz7RKS+WqTQbLbE5Vms5Lbm1eFjNbuZFxLNfuR3wobP/Xfuwznoyi9Eoij4hUgUBb8QiaLgFyJRFPxCJIqCX4hE6anUl8tkMNo/ELQVc7yoZv/EbLD9+vlIocX5N6itefQ71La4jsuAmRveETbcsIX2wSouDWUmDlJb66dccsxOz1Fbs1oJtu93LosOEzkMAEbL4e0BQLHGMydbxfChZXVeOBN17ocVeHZkC5FinGR/mWwkIzGyPUSKpzb5UMEiRVJLpbB0e7TJx2OBXLYrp89wJ85BV34hEkXBL0SiKPiFSBQFvxCJouAXIlF6OtvvLUe9Gk48qVX5LOrWV8JJKSXnM6iNBl8WqgE+i1qaDi+dBAD9p6eD7f7Ms7SPt7gf9ciSUfVIbUWLnLMtG05K2Zjlako+ww+DrEeSZpzP9mcQ/mxifSxiQ4uPVaRyHuDh8ciQZLF2n8jYW+x6yW31iILwMEkk+lJkV7PExaMN/nmdi678QiSKgl+IRFHwC5EoCn4hEkXBL0SiKPiFSJQlpT4z2wDgrwCsQ3u5rp3u/jkz+zSAjwJ4swDcJ939u7FtZXNZjIyGa/g1ZrgUMn4oLL/VFsMJPwAQW4YsG1F5KhVez+6H+bBctrCe19uzGpf6xud4JsjmeW4zuoASgEZ4HPPnIQGdTZNIZW0/OM6skU4RoW+JfcWIbTVMM7IziyT2FCKe/HVkabPPDIeXG9v6Dr6s3IZi2Mkzz7xM+5xLNzp/A8An3P0FMxsC8LyZPdGxfdbd/1vXexNCXDF0s1bfCQAnOn/Pmdk+ALzsrBDibcF53fOb2UYAt6O9Qi8AfMzMdpvZo2bGv/sKIa44ug5+MxsE8A0AH3f3WQCfB3A9gNvQ/mbwGdJvh5k9Z2bPzS3y4htCiN7SVfCbWR7twP+iu38
TANx9wt2b3n5Y+wsA7gz1dfed7r7d3bcP9UcWmxBC9JQlg9/MDMAjAPa5+8NntY+f9bIPAth76d0TQlwuupntvxvA7wLYY2a7Om2fBPBhM7sNbaXlEIDfX2pDmUwGpVJY1sj9iEsUI9PhbLpqRFqJyWE147Y/6ue14nZtWBNsv2bbVtpn9bqN1Hb6tZeobfMPeKbgf4zU3MuS992KnOdjUllkqNC08x//TFSXi22PE9umkzcQfc+RveVaXDqciYzHV/I81DaNh+tGPvBvfoP2GRgIH6d7Xns42B6im9n+HyA89lFNXwhxZaMn/IRIFAW/EImi4BciURT8QiSKgl+IROl5Ac/aYlimeufrPEMvVww/HGTlcDHQNjz76nuFPmr7/ih/SvmWVYPB9gLmaZ+xQb6vylh4ewDwnQ2rqe3Og+GCpgDwHlKYMrIAFQqRDMhYTlw20u9ChMWYj5HkwgsitrlYQdAj145S2+Eyz+A8FhnIW8iSbq8eeoX2GVs5HGyv1rt/ilZXfiESRcEvRKIo+IVIFAW/EImi4BciURT8QiRKT6U+ZHLI9oelkmffxTPj7NWwrFH62au0z3CTCza7MlxUyvEl7VAikuM1AwO0T+3063x7ziXC4RUrqO0fSmeo7d758HvLRdYFjGW4XfgBEt7qBe/rArU+X6K8ZwiL9OmrcHn5uPNraabIs0XHSCZpa+Eg7VOrhCVkr/PCr7/gU9evFEL8k0LBL0SiKPiFSBQFvxCJouAXIlEU/EIkSk+lPjOgUAinN01cHc5sAoCvHQ/LVC+s4RJbY4ZLHj9rctnLWvx8WBgKy5Tr1oQLMLa3t0htP1/gpcxr1TK1nXb+sU2NhyXCya030T75Ji8ImotIbJlmZD1EZotVBI3lELYiUmXm/Ffya5E1DQEgE7km9s/xz7N2dD+12QCXnhukKOimkXW0T6sZziDMZbq/nuvKL0SiKPiFSBQFvxCJouAXIlEU/EIkypKz/WZWAvA0gGLn9V9390+Z2XUAvgxgFMALAH7X3aMFxLKZLAYGwjPmxRKfcf6HUvgc9ePILPV8hs8c5yIV3IZmeS3BfF+4vt/4TffQPgtnTlPbySNPUdt8lc9GP9/gSsZfVMKzykdOH6d9spHJ8kKGz1IXjNtaZAY+m+V9LKoERJbyiigSbOkty/LrXnSpt2Gu0Lya4/08ImTMNcNhWOvnNR5LRWLLdb8YbjdX/iqAe939VrSX477PzO4C8CcAPuvuWwBMAfhI13sVQiw7Swa/t3kz9zTf+XEA9wL4eqf9MQAfuCweCiEuC13d85tZtrNC70kATwB4HcC0u7/5vfsogPWXx0UhxOWgq+B396a73wbgagB3AtgWelmor5ntMLPnzOy5mXn+1JoQorec12y/u08D+HsAdwEYMbM3ZyquBhCcUXL3ne6+3d23r4gsYCGE6C1LBr+ZrTazkc7ffQD+FYB9AJ4C8Budlz0I4NuXy0khxKWnm8SecQCPmVkW7ZPFV939f5vZywC+bGb/BcBPATyy1IbyhQKuujo8NeB5LlHcXQ7XurthfA3ts1DhcliryXWXQxO8Pt7evXuC7VtvuIP2GRzgcs0bJ6epbWZyktqqfVxS+otMWG3NHOH14OYqXKGt12MJMBFpi7VHSuqZcWOsEl9MIGRXt1guUCEi2Y0M8gS0kyTZBgDqU1xCPjk5F+5jfF+brr092F4oPE77nMuSwe/uuwH8wp7c/QDa9/9CiLchesJPiERR8AuRKAp+IRJFwS9Eoij4hUgU85j2cql3ZnYKwM87/64CwFPeeof8eCvy46283fy41t1Xd7PBngb/W3Zs9py7b1+WncsP+SE/9LVfiFRR8AuRKMsZ/DuXcd9nIz/eivx4K/9k/Vi2e34hxPKir/1CJMqyBL+Z3Wdmr5rZfjN7aDl86PhxyMz2mNkuM3uuh/t91MxOmtnes9pGzewJM/tZ53e4Wujl9+PTZnasMya7zOz9PfBjg5k9ZWb
7zOwlM/sPnfaejknEj56OiZmVzOwZM3ux48cfddqvM7OfdMbjK2bWfbXOEO7e0x8AWbTLgG0CUADwIoAbe+1Hx5dDAFYtw37fA+AOAHvPavuvAB7q/P0QgD9ZJj8+DeAPezwe4wDu6Pw9BOA1ADf2ekwifvR0TNDOUh7s/J0H8BO0C+h8FcCHOu1/BuDfX8x+luPKfyeA/e5+wNulvr8M4P5l8GPZcPenAZybsH8/2oVQgR4VRCV+9Bx3P+HuL3T+nkO7WMx69HhMIn70FG9z2YvmLkfwrwdw5Kz/l7P4pwP4vpk9b2Y7lsmHN1nr7ieA9kEIgFcqufx8zMx2d24LLvvtx9mY2Ua060f8BMs4Juf4AfR4THpRNHc5gj9UJmW5JIe73f0OAL8K4A/M7D3L5MeVxOcBXI/2Gg0nAHymVzs2s0EA3wDwcXfnpW9670fPx8QvomhutyxH8B8FsOGs/2nxz8uNux/v/D4J4FtY3spEE2Y2DgCd3yeXwwl3n+gceC0AX0CPxsTM8mgH3Bfd/Zud5p6PSciP5RqTzr7Pu2hutyxH8D8LYEtn5rIA4EMAui88dokwswGzdpE0MxsA8D4Ae+O9LiuPo10IFVjGgqhvBluHD6IHY2LtdboeAbDP3R8+y9TTMWF+9HpMelY0t1czmOfMZr4f7ZnU1wH8p2XyYRPaSsOLAF7qpR8AvoT218c62t+EPgJgDMCTAH7W+T26TH78NYA9AHajHXzjPfDj3Wh/hd0NYFfn5/29HpOIHz0dEwC3oF0UdzfaJ5r/fNYx+wyA/QC+BqB4MfvRE35CJIqe8BMiURT8QiSKgl+IRFHwC5EoCn4hEkXBL0SiKPiFSBQFvxCJ8n8Bk39i95AyJoUAAAAASUVORK5CYII=\n", "text/plain": [ "
"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.imshow(img.permute(1, 2, 0))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 32, 32, 50000])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imgs = torch.stack([img for img, _ in cifar10], dim=3)\n",
    "imgs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.4915, 0.4823, 0.4468])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imgs.view(3, -1).mean(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.2470, 0.2435, 0.2616])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imgs.view(3, -1).std(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Normalize(mean=(0.4915, 0.4823, 0.4468), std=(0.247, 0.2435, 0.2616))"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transforms.Normalize((0.4915, 0.4823, 0.4468), (0.2470, 0.2435, 0.2616))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "cifar10 = datasets.CIFAR10('data', train=True, download=True,\n",
    " transform=transforms.Compose([\n",
    " transforms.ToTensor(),\n",
    " transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
    " (0.2470, 0.2435, 0.2616))\n",
    " ]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "cifar10_val = datasets.CIFAR10('data', train=False, download=True,\n",
    "
transform=transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4915, 0.4823, 0.4468),\n", " (0.2470, 0.2435, 0.2616))\n", " ]))" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEGZJREFUeJzt3X+s1fV9x/Hnu/wQLKz8hjtAUUJT3FAgd8QG7bTbHJpu6JyrrjGYOK9baqoJTUZoNpjpktn4I6Rt7K5CSjsqMEVh1q0lxIa6TOtVEVAsUERFbvlRJaJO9MJ7f5wv2YV+P59z7jnf8z0XPq9HQu45n8/5fr9vv97XPed8v9/P52vujoik51OtLkBEWkPhF0mUwi+SKIVfJFEKv0iiFH6RRCn8IolS+EUSpfCLJGpgIwub2TxgGTAAeNjd/6XK63U5YSLGDB2U2374fz8puZJ8F55nwb4PPg7/mh74dXidQ0eE+0ZH+gYPyW8ffm54mZ2v5bd/fAx6ejz8H9eL1Xt5r5kNAHYCfwLsA54HbnL3VyPLKPyJuHXGxNz25dveLrmSfI9+75xg37NvHQv23fvP4XVe8hfhvpv/PNw3aXp++5WzwstcNTe/fecr8OEHtYW/kY/9c4Dd7r7H3T8GVgPzG1ifiJSokfBPBN7q9Xxf1iYiZ4BGvvPnfbT4rY/1ZtYBdDSwHRFpgkbCvw+Y3Ov5JGD/6S9y906gE/SdX6Q/aeRj//PANDO7wMwGAzcCG4opS0Sare53fnfvMbM7gJ9QOdW3wt1fKawyOaP1l6P6gwPt0yZ9K7jM9R2zg31Pb7482Hd15Ih+++fDfa+9ld/+0o7wMlMCZwj27gkvc7qGzvO7+1PAU42sQ0RaQ1f4iSRK4RdJlMIvkiiFXyRRCr9Iouoe2FPXxnSRj5zhbv/rcN/7kZF7gYF7AAxvy28/2hNeZvl3Ax1HwD9p/sAeETmDKfwiiVL4RRKl8IskSuEXSVRD1/aLpGbLtnBfaLANwLOvh/te35Xf/mGskCOxztronV8kUQq/SKIUfpFEKfwiiVL4RRKl8IskSgN7RM4y7hrYIyIRCr9IohR+kUQp/CKJUvhFEqXwiySqoVF9ZrYXOAocB3rcvb2IokSk+YoY0nulux8uYD0iUiJ97BdJVKPhd+CnZvaCmXUUUZCIlKPRj/1z3X2/mY0DNprZa+6+ufcLsj8K+sMg0s8Udm2/mS0F3nf3eyOv0bX9Ik3W9Gv7zezTZjb85GPgKmB7vesTkXI18rF/PPC4mZ1cz4/c/b8KqUpEmk5DekXOMhrSKyJRCr9IohR+kUQp/CKJUvhFEqXwiyRK4RdJlMIvkiiFXyRRCr9IohR+kUQp/CKJUvhFEqXwiyRK4RdJlMIvkiiFXyRRCr9IohR+kUQp/CKJUvhFEqXwiyRK4RdJlMIvkiiFXyRRVcNvZivM7KCZbe/VNsrMNprZruznyOaWKSJFq+Wd//vAvNPaFgGb3H0asCl7LiJnkKrhd/fNwDunNc8HVmaPVwLXFlyXiDRZvd/5x7t7
N0D2c1xxJYlIGRq5RXdNzKwD6Gj2dkSkb+p95z9gZm0A2c+DoRe6e6e7t7t7e53bEpEmqDf8G4AF2eMFwPpiyhGRspi7x19g9ghwBTAGOAAsAZ4A1gLnAW8CN7j76QcF89YV35iINMzdrZbXVQ1/kRR+kearNfy6wk8kUQq/SKIUfpFEKfwiiVL4RRKl8IskSuEXSZTCL5IohV8kUQq/SKIUfpFEKfwiiWr6ZB7SP8yP9Gk8dpr0zi+SKIVfJFEKv0iiFH6RRCn8IonS0f6zzDcD7d/47zuDy4yeuyzYV3ViRjlj6Z1fJFEKv0iiFH6RRCn8IolS+EUSpfCLJKqW23WtAL4EHHT338/algK3AYeyly1296eqbkx37GmZRyN9188K9615Kdz35atHB/vsP39TvShpiiLv2PN9YF5O+wPuPjP7VzX4ItK/VA2/u29G13qInHUa+c5/h5ltNbMVZjaysIpEpBT1hv9BYCowE+gG7gu90Mw6zKzLzLrq3JaINEFd4Xf3A+5+3N1PAA8BcyKv7XT3dndvr7dIESleXeE3s7ZeT68DthdTjoiUpeqoPjN7BLgCGGNm+4AlwBVmNhNwYC9wexNrlD5Y/eTW3PYtK/41uMx1674b7Hs2sq0bIqfznhiT337t4cgKI+bPmBjsW7/t7fpWmriq4Xf3m3KalzehFhEpka7wE0mUwi+SKIVfJFEKv0iiFH6RRFUd1VfoxjSqr+nq+v+58mfBLrvlymDf4Mgqjz18a277P/xN+ERRaPJRgDcevjvY97VVq4N9659+NbLWvhsX6Ytd4/7LQquIK3JUn4ichRR+kUQp/CKJUvhFEqXwiyRK4RdJlE71FSD2HzUl0vdGwXXE+P73w51f//tg1+d+FB7xFzt99WSg/fHIMh9F+h6J9J2I9P3upPz2FUfCy/zp9PDpTYjsx2lTw32vRyY0/Z+Nke31TTvQpVN9IhKj8IskSuEXSZTCL5IohV8kUTraf5qiC4wNK/m9grcV853LLwr2Dfx5uMorIwe+P/vj2PmKYYH28Hx7du7FkfWFjQoc0Qf4Ws/43PYlk/PbAfi38BkOPntZjVX1wVV5M+UBG8MDlkJ0tF9EqlL4RRKl8IskSuEXSZTCL5IohV8kUVVP9ZnZZOAHwAQqYyg63X2ZmY0C1lAZu7IX+Ct3f7fKuvrFqb5+UQTwt5G+8M21ihebl+5AdMnYDZ966qpFGlP0qb4eYKG7TwcuBb5qZhcBi4BN7j4N2JQ9F5EzRNXwu3u3u7+YPT4K7AAmAvOBldnLVgLXNqtIESlen77zm9kUYBbwHDDe3buh8geC+KdHEelnqt6l9yQzGwY8Btzl7u+Z1fS1AjPrADrqK09EmqWmd34zG0Ql+KvcfV3WfMDM2rL+NuBg3rLu3unu7e7eXkTBIlKMquG3ylv8cmCHu9/fq2sDsCB7vABYX3x5ItIstZzquwz4ObCN/58ubTGV7/1rgfOAN4Eb3P2dKusq9Czb3EjfM0VuSMox4fJw3/TZkb7zwn0jA6P33o2cxBwa+TZ89Zciy4VGMgJjIofEQpubOiS8DMdyW/tyqq/qd353fwYIreyPatmIiPQ/usJPJFEKv0iiFH6RRCn8IolS+EUSVeoEnoPNfGygL9QO4Rsk7W6wnnJETqhMvz3cF5s5Mzb55OuBCTLXRSaDPPxEuC8qcootOB4w/xTV2eEz4a4Jnw/3Lfyz/PZdkVuD7dqZ29zetZ6u9w5pAk8RCVP4RRKl8IskSuEXSZTCL5IohV8kUaWe6htr5vMDfZMjy30u0P7lBuspxcA/CPf1PF9eHZIE3atPRKpS+EUSpfCLJErhF0mUwi+SqFKP9o8w8ysCfbGbOz3ZhFpE+ouZgfaX61yf62i/iMQo/CKJUvhFEqXwiyRK4RdJlMIvkqiqd+wxs8nAD4AJVG7X1enuy8xsKXAbcCh76WJ3fyq2rt8BQjPTHam14hb6MNC+PbJMbAdHbkAlZ5kb
I331ntJrVC236O4BFrr7i2Y2HHjBzDZmfQ+4+73NK09EmqWWe/V1A93Z46NmtgOY2OzCRKS5+vSd38ymALOo3KEX4A4z22pmK8xsZMG1iUgT1Rx+MxsGPAbc5e7vAQ8CU6lcndgN3BdYrsPMusysKzILuYiUrKbwm9kgKsFf5e7rANz9gLsfd/cTwEPAnLxl3b3T3dvdvT1y93IRKVnV8JuZAcuBHe5+f6/2tl4vu474QW8R6WdqOdo/F7gZ2GZmW7K2xcBNZjYTcGAvELn3VMXggTBlTH7fiF/XUEkJahoO1WLljcOUoqypY5mvzLo22DdjRv4x92//eG3N66/laP8z5Gciek5fRPo3XeEnkiiFXyRRCr9IohR+kUQp/CKJKnUCz/PNfHGgr+p5wgKtjPTdUvC2Yn9dT9S5ztgosIvrXKc07s1I3/kFb+vcQPtHwHFN4CkiMQq/SKIUfpFEKfwiiVL4RRKl8IskqpZRfYUZMBCGBUb1LYuM6ruz4DpuKXh9MfWezou5JNKnEX+t82CJ2wpNJtsXeucXSZTCL5IohV8kUQq/SKIUfpFEKfwiiSr1VN+gQTChLb/vh5FTfXcH2t9puKJiXB/pi+3geiZ1lP6ru+D1/WGk76NAe1+m0NY7v0iiFH6RRCn8IolS+EUSpfCLJKrq0X4zGwJsBs7JXv+ouy8xswuA1cAo4EXgZnf/OLauoed+ihkzhub2TXrpg+ByP6lWZIvd9u3Vwb7tG/4j2Ldm46rCa/lMoP29wrckzRa7g92UIfntA47Vvv5a3vmPAV9090uo3I57npldCtwDPODu04B3gVtr36yItFrV8HvF+9nTQdk/B74IPJq1rwTCdxUUkX6npu/8ZjYgu0PvQWAj8CvgiLv3ZC/ZB+TfNlRE+qWawu/ux919JjAJmANMz3tZ3rJm1mFmXWbW9ZuPNNWESH/Rp6P97n4E+BlwKTDCzE4eMJwE7A8s0+nu7e7ePnpITfcSEJESVA2/mY01sxHZ46HAHwM7gKeBv8xetgBY36wiRaR4tQzsaQNWmtkAKn8s1rr7k2b2KrDazL4JvAQsr7qx8WMZt/AruX13j30iuNz2+/bktj9XtfRyLLknfKpv5oxyb6ClU3pnj0ORvnuW5Odl93cW1rz+quF3963ArJz2PVS+/4vIGUhX+IkkSuEXSZTCL5IohV8kUQq/SKLMvbyr7szsEPBG9nQMcLi0jYepjlOpjlOdaXWc7+5ja1lhqeE/ZcNmXe7e3pKNqw7VoTr0sV8kVQq/SKJaGf7OFm67N9VxKtVxqrO2jpZ95xeR1tLHfpFEtST8ZjbPzH5pZrvNbFErasjq2Gtm28xsi5l1lbjdFWZ20My292obZWYbzWxX9nNki+pYamZvZ/tki5ldU0Idk83saTPbYWavmNmdWXup+yRSR6n7xMyGmNkvzOzlrI5/ytovMLPnsv2xxswGN7Qhdy/1HzCAyjRgFwKDgZeBi8quI6tlLzCmBdv9AjAb2N6r7VvAouzxIuCeFtWxFPh6yfujDZidPR4O7AQuKnufROoodZ8ABgzLHg+iMnr9UmAtcGPW/j3g7xrZTive+ecAu919j1em+l4NzG9BHS3j7pv57fuMzqcyESqUNCFqoI7SuXu3u7+YPT5KZbKYiZS8TyJ1lMormj5pbivCPxF4q9fzVk7+6cBPzewFM+toUQ0njXf3bqj8EgLjWljLHWa2Nfta0PSvH72Z2RQq80c8Rwv3yWl1QMn7pIxJc1sR/ryJ/Fp1ymGuu88Grga+amZfaFEd/cmDwFQq92joBu4ra8NmNgx4DLjL3Vs2KVFOHaXvE29g0txatSL8+4DJvZ4HJ/9sNnffn/08CDxOa2cmOmBmbQDZz4OtKMLdD2S/eCeAhyhpn5jZICqBW+Xu67Lm0vdJXh2t2ifZtvs8aW6tWhH+54Fp2ZHLwcCNwIayizCzT5vZ8JOPgauA7fGlmmoDlYlQoYUTop4MW+Y6Stgn
ZmZU5oDc4e739+oqdZ+E6ih7n5Q2aW5ZRzBPO5p5DZUjqb8CvtGiGi6kcqbhZeCVMusAHqHy8fETKp+EbgVGA5uAXdnPUS2q44fANmArlfC1lVDHZVQ+wm4FtmT/ril7n0TqKHWfABdTmRR3K5U/NP/Y63f2F8Bu4N+BcxrZjq7wE0mUrvATSZTCL5IohV8kUQq/SKIUfpFEKfwiiVL4RRKl8Isk6v8AE+A/IhEVAkwAAAAASUVORK5CYII=\n", "text/plain": [ "
"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "img, _ = cifar10[99]\n",
    "\n",
    "plt.imshow(img.permute(1, 2, 0))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_map = {0: 0, 2: 1}\n",
    "cifar2 = [(img, label_map[label]) for img, label in cifar10 if label in [0, 2]]\n",
    "cifar2_val = [(img, label_map[label]) for img, label in cifar10_val if label in [0, 2]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# label_map above collapses CIFAR-10 down to two classes, so the network needs two outputs.\n",
    "n_output_features = 2\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(3072, 512),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(512, n_output_features))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(x):\n",
    "    \"\"\"Softmax over all elements of x, computed in a numerically stable way.\"\"\"\n",
    "    # Subtracting the max cancels out in the ratio (same result), but keeps\n",
    "    # torch.exp from overflowing to inf/nan when x contains large values.\n",
    "    exps = torch.exp(x - x.max())\n",
    "    return exps / exps.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0900, 0.2447, 0.6652])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
"output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([1.0, 2.0, 3.0])\n",
    "\n",
    "softmax(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1.)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "softmax(x).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0900, 0.2447, 0.6652],\n",
       " [0.0900, 0.2447, 0.6652]])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "softmax = nn.Softmax(dim=1)\n",
    "\n",
    "x = torch.tensor([[1.0, 2.0, 3.0],\n",
    " [1.0, 2.0, 3.0]])\n",
    "\n",
    "softmax(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nn.Sequential(\n",
    " nn.Linear(3072, 512),\n",
    " nn.Tanh(),\n",
    " nn.Linear(512, 2),\n",
    " nn.Softmax(dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
     ]
    },
    {
     "data": {
      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGXFJREFUeJztnXt4VeWVxt8lAUEDRQhICtiAYos3Lg14G6yCqPg4g1i1OtWh1hHt6Iyd6Uyljq20tR3tU7XaUgWrI/YRFeu99YYpLVQZJCImYqxcREECJGKEiBhD1vxxDhriXisn+5yzT9Lv/T1PniTrzbf3OvucN/ucvfa3PlFVEELCY59CJ0AIKQw0PyGBQvMTEig0PyGBQvMTEig0PyGBQvMTEig0PyGBQvMTEihF2QwWkdMA3AKgG4DfqOr13t/3LhHtXxatvfWGM3Df6PA+Pe0h3aWbqfXaz37YBxSXmFofDIyMFzn/Qxux3dQ27lhjasW97TsvP2cqQHcjvssZYxxeAP7Zwbs3dLcR398Zkw8ajHiTM+ZD8ygC3qPevqPZ1Jo+dDa509Es3jfiuwFtUclkExL39l4R6QbgDQCTAWwEsBzA+ar6mjWmrFz0+5XR2j+f4uxsWHS4z0jbxAOKbIuMPqq/qU2bcLGpTZbLI+MDnZf0C3jO1L5bcYapHTvpI1OzRwEDjPhqZ4xxeAEAxY7m/UNpNOLjnTFxaXG0J4z4emdMDQabWjNsgz9XscXU3qpxdviyo1k8acTrAf04M/Nn87Z/PIA1qrpOVZsA3A9gahbbI4QkSDbmHwxgQ6vfN6ZjhJAuQDbmj3pr8ZnPECIyQ0QqRaRyR10WeyOE5JRszL8RwNBWvw8BsKntH6nqXFUtV9Xy3tYHUkJI4mRj/uUARojIMBHpAeA8AI/nJi1CSL6JXepT1WYRuQLAM0iV+u5S1VXemG5wrh5/1Rl4WXR4+0j7yuv2I981tTc329qaF+fY2sTow3X+2JPMMcc5hblfT+praqthXzk2CiYA7CvwQ5wx9lEE6h3Nzj7uVf3hjnaUqVRiuan9z2PvRMaLh0aGU5TYj7riVrsK02OMs007Rbse6WE90R0o3mVV51fVJ2EXHQghnRje4UdIoND8hAQKzU9IoND8hAQKzU9IoGR1tb+jbAHwS0ObdKk9rsKoD44a49VW7HLeK99729YeX2drU74TGa+eZZehJo+vMjWvwuNMWMRGR7MqSlOcMYMc7VhH64MDHdU6/l5h0S6jvYAzTW3hY3Y5ddmZ86KFs+wsjv6VnQeOtKWmF20Nbzqa5cJFzpgcwDM/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGS6NX+D94Bnr8mWjv4Onvc5V+Pjs9+yOl/5LVNGudo3rzEp6LDSy+3r+j/vbM572r/TY7mMdmIe9fYvTZefdz+LNE9DQHgMuPRTcah5phxTjfBOqehWPXQC0wNMK72OwfkS6W21jDB1v7qXdF3tlmo2TE88xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGSaKkPmwH8JFpa6wyb/Q+G4M1+cRrMHeysDrTWKxHOjw5vchrdfWOxsz0n/4Exl7axFhuzSoAAcKi7AJi97NkfPtus+RN2GU/AMNiTsZbC7oV4ntfkcawtmY+8bKE5YqG55hSwabazK2/G1QZHs5Y3yjM88xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKDQ/IYGSValPRNYD2AFgN4BmVS2PvbF7Hc0qvzn91PBNW2p2lmo6+le2tsxaImm1k4eHM4Ow8WZb+5eDbM1aCHmJk0Y13je1Mkdb6WxznNHf7z38yRxzH86xNxi1JnRGnB0dbt7fHLFp9qP25ryZewc4WidcoToXdf6TVNVb0o0Q0gnh235CAiVb8yuAZ0XkJRGZkYuECCHJkO3b/uNVdZOIDASwUEReV9W9bmhN/1PgPwZCOhlZnflVdVP6+1YAjyBiWXZVnau
q5VldDCSE5JzY5heR/UWk956fAZwC4NVcJUYIyS+iqvEGigxH6mwPpD4+zFdVY87eJ2Pi7ezHRvweZ4y3zpQzm+6Lc2ztEiN+hLOreqcB5owX3zG1ndX2NkddbGvWBDFv1uTxjvYtR7NmEAJAKaLrkdXYbY65oMqpYY5yOmfiZ44WA++BTXQ0u8cosNTRrBl/MWf7qWpGhdHYn/lVdR2AUXHHE0IKC0t9hAQKzU9IoND8hAQKzU9IoND8hARK7FJfrJ3FLfVNMeK9nTHeTLv3HM1qFgoAJxhxZ+2/rzrVK6fHKO5c4YheM0ij8efhzlpxzlw6t4zpTI5EmRGvQX9zzIkVh9kbPNmaUgkAy23JKr+NdDZ3rqN5DV5rHc1Y5zEfZFrq45mfkECh+QkJFJqfkECh+QkJFJqfkEBJdrkuDy+TNUb8n2Lua4GjPexo1gpPZfaQhy53tucsybWPvaoVBjnLU40w4hc6aRziaB5eWzpLa8a79qCFh8ZL5G7nar/B1Om2NsgZN+dSR7RXAOuU8MxPSKDQ/IQECs1PSKDQ/IQECs1PSKDQ/IQESucp9TU72g4j7k2y8HB6+LlHZLIR9yYYbXa0eU4aV9rahO7ONg28JZWsh9UecQ6/Nz0HN9iTfuD0Qrx9+mmmdgiejox7x2O9o7kDvddwJ4RnfkICheYnJFBofkICheYnJFBofkICheYnJFDaLfWJyF0AzgCwVVWPSMf6AXgAqfls6wGcq6peZ7zssMpl850xzqw4c+ob4Pf+O8CIe7MLvRmEznJMTU6fvvXDbW2IEe+JA80xT2GLqXl9Bp9wNGsipo+3N3u9qzKnh5/1lHn5NcOeXTjqyjdM7ZUjnY3+0NGs16pXkh5gxP/sjGlDJmf+uwG0LaTOBFChqiMAVKR/J4R0Ido1v6ouBrCtTXgqPr1FZR6AM3OcFyEkz8T9zH+gqtYCQPr7wNylRAhJgrzf3isiMwDMyPd+CCEdI+6Zf4uIlAJA+vtW6w9Vda6qlqtqecx9EULyQFzzPw5gTxe06QAey006hJCkaHe5LhG5D8CJAEoAbAFwLYBHkSpiHQTgbQDnqGrbi4JR20pubTAPr0OjNwvPKuV4H2p6OZoznc5b5usc7O9sNHp9qhKn1FePKlP7P2dPv/jQEW804vc4Y1b/xtZG2p1Qv/baR6Y2wYh/CUeZY8bhFlNrxhxTK4L9pC3Hwaa22Tj+jbDLiq/r2sj4/HEbsaXyo4yW62r3M7+qnm9IkzLZASGkc8I7/AgJFJqfkECh+QkJFJqfkECh+QkJlM7TwLOz4M2kqjbi33PG/NyWpjvlPGs2GgCsh93ossQoKRU5T/VKZ1+/8O7g+KOjWY0uvVmTWGdLk+0npgF2qc+q3BY75c0ifNfUSvG2qR1q1jeBSfi6qVmtULfhHXNEPzk5Mr4Emd9LxzM/IYFC8xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKF271Odl762b1uBo7mJyBk4jTr+0Zdf6ijDN2Z3dKXKIMVutHu+aY5asMCVg6UJby/m6dT81laMP2NfU/t3ZopWi18BzidMQ1Ht5/AgXmNpwZ5vA5yOjy/GBOeJUnORsLzN45ickUGh+QgKF5ickUGh+QgKF5ickULr21f5YV5QR74p+XKJb6qWk9+yr2x/uOs/UDinpZm/UeEaLnIrEJWPbLsj0KReNtcetsZs2o/qZ6Ek6f1hwg71BPGoqE5rtyTunftJL9rPc+MnaMnsTZyUsAKh1tDcdbYjTF9B6arz+iZUf3h0Zr23xmlDuDc/8hAQKzU9IoND8hAQKzU9IoND8hAQKzU9IoLRb6hORuwCcAWCrqh6Rjs0CcAmAuvSfXa2qT+YryU6PU87r1/gDU3twtr2E04C+djmvYYS9v0aj0rNmtV0qKxthT5rp2dfe14SJ9srsg46L1p4660pzTMvDdqlvqVNHe80o5wHAaCM+DIPNMRuc3nm9Hcs0w37O/tfpMzjEiE8xRwA
9e0VP4Jq/z/vOqL3J5Mx/N4CoQvDNqjo6/RWu8QnporRrflVdDKDdRTgJIV2LbD7zXyEiVSJyl4h4naYJIZ2QuOa/DcDBSH2kqoW9IDNEZIaIVIpIZcx9EULyQCzzq+oWVd2tqi0A7gAw3vnbuaparqqZryZACMk7scwvIqWtfp0G4NXcpEMISYpMSn33ATgRQImIbARwLYATRWQ0AAWwHsCleczR5PNldolq2AnmmxEU7bIf9p8XLOp4IsP+w5S2vTnBHlf3liltHbG/qdVutstU26rfiBaqVpljVjXaveLQaJeOHho3xtR6jIkuY7Y87PQEdHjeWioNwK+dcb2NeJ1TzhvpbG+yM5W0r6N5bSOtjozj4c2AvCwy2gtfccbsTbvmV9XzI8J3ZrwHQkinhHf4ERIoND8hgULzExIoND8hgULzExIoiTbwHNz/8/jXqdElip4n2GWjnmMOi4yfNGy4OabYqvHAnYSHswZdYWoVt94fLVjlNQCofttJxC7nod5eQ2tb3YHOuOjGmXBKW0B/R3PW8lpiz1hsWmJt83POvhycUp/Xj/VpI772OmeQ16XTaWh66cW29hdnk1b+xzkNTe1EnLJtG3jmJyRQaH5CAoXmJyRQaH5CAoXmJyRQaH5CAiXRUt+gslJcdef3k9xlh6mpdxa1w7tG/PfxdubtqsYrv51tS32PjY43OGVFOOVIZz0+H+tYWfH4eGvamS9w75XvTRN0pvzNcZqdmlP3AKwaFh1/ovtSc8xPMDkyvt1JoS088xMSKDQ/IYFC8xMSKDQ/IYFC8xMSKIle7e8K1Fd7kymSxLsqPseWGqw+cnZ/OcCYsNSZcF6pqx5zxp0QHf7yTHvISxuc7RnLoQEAvHGnd3zcSxvtIbcZj6sjtRme+QkJFJqfkECh+QkJFJqfkECh+QkJFJqfkEDJZLmuoQDuATAIQAuAuap6i4j0A/AAgDKkluw6V1Xfy1+qHaMJm0yth1NGK6q2l6dqyiqjpPgbXUxphqMNdTSjwlntvFK/6PT367vD1mqcMmDPXra21eg3ebjd1hK7PoyOa4s9pi2ZnPmbAXxHVUcCOAbA5SJyGICZACpUdQSAivTvhJAuQrvmV9VaVV2R/nkHgBoAgwFMBTAv/WfzAJyZryQJIbmnQ5/5RaQMwBgAywAcqKq1QOofBICBuU6OEJI/Mja/iBQDeAjAt1U1454BIjJDRCpFpLKuri5OjoSQPJCR+UWkO1LGv1dVH06Ht4hIaVovhXFbsarOVdVyVS0fMGBALnImhOSAds0vIoLUJeQaVb2plfQ4gOnpn6cD8KZXEEI6GZnM6jsewIUAqkVkZTp2NYDrASwQkYuRagJ3Tn5SBLYZ8UZYS1MBDfqcqQ3CFlPbmWlSJFGOnm1ry56xtT7GqlbeC7/WaWl40UFHmdq0g6pMzZtTeY0x7JhJ9hirJeBrHbiK1675VfUvAMSQnfQIIZ0Z3uFHSKDQ/IQECs1PSKDQ/IQECs1PSKB0iQae/Yx4MYabYzb/8R1Te6p+iantV2znsdNbXotkz5SY4162pQNOjY5700+nHWRr52BfU+vpbHORox0/MTruTVa874Xo+LYOvEZ55ickUGh+QgKF5ickUGh+QgKF5ickUGh+QgKlS5T64lAybLCplU20OyOOqbbLgM//JHpu1pevsvN4yZb82tBqR5vvbTRBjnW0pTG2d40tTcbnTG30TPtlvMZo1rpc7X3tsqaxAbgJy03Nq1Q6y+5hgrG/OifHDW9Gx5s60GWWZ35CAoXmJyRQaH5CAoXmJyRQaH5CAiXRq/0tsHvkNRrLDwFAX2OpoyJ8YI4ZPtye9NO4Y7GpWVf0PWrmOOLpjuZ1Mh/R4TSSpyHGmCGO5iyFdd1Eexk1jHS2aVQQ9nEmcD1gXEkHADgTZ54+ztZOczY5wYg3OFWHhrOi40/d6OyoDTzzExIoND8hgULzExIoND8hgULzExIoND8hgdJuqU9EhgK4B8A
gpKp1c1X1FhGZBeASfFqwulpVn/S2tQ+A/Qyt3ikb9TBKfVvxe3PMgw+cZ2pX2JL737DFiO/0Sl5xJ+EsjDkuSTpeFQWM5xIA8A1H2+xoXoO88dHhFq+Jnzcp6VxbWvtzW5ttzxfDGGOVy4tgT06r7hXdo7JbLpfrQuop/o6qrhCR3gBeEpE9L82bVdV5yISQzkoma/XVAqhN/7xDRGoA518SIaRL0KHP/CJSBmAMgGXp0BUiUiUid4nIATnOjRCSRzI2v4gUA3gIwLdVdTuA2wAcDGA0Uu8MIm8sFJEZIlIpIpV1dd79rISQJMnI/CLSHSnj36uqDwOAqm5R1d2q2gLgDhiXVlR1rqqWq2r5gAEDcpU3ISRL2jW/iAiAOwHUqOpNreKlrf5sGoBXc58eISRfZHK1/3gAFwKoFpGV6djVAM4XkdEAFMB6AJe2t6FG7MILqInUajesNcdVvxYd/+0iu2b3wLPtZRONVc7rVPybo92a431da0s9jrS1prMNwetNGBen/GbOFKx3xixwNK+YHXM5t18aM1qPNMp5AHBHVXT8Y2d2bFsyudr/FwBRkwvdmj4hpHPDO/wICRSan5BAofkJCRSan5BAofkJCZREG3hu/7geC2vvjtSqX7jXHNe4OrrksehlZ2deE8YuzqSv2VpFrkt982ypyZvNOM6I26tdxWeYow014t1j7itmOc9ryPqK8Tp+xGkIWmI8rroemafEMz8hgULzExIoND8hgULzExIoND8hgULzExIoiZb6PvpgG1a/GF3SK+ptz2AqMZowTnDWaKv4L1vrY0vY7mgWp06xtWeeirFBAJOsUhmAMWNsrcKa8Re3BLje0fo6mlXa8pp+eqVbjziNRE9ytH90tLgNWb3ZjMZaj9c7axeOOjU6/l63jDPimZ+QUKH5CQkUmp+QQKH5CQkUmp+QQKH5CQmUZEt973+MNU9Gl/SKrdlXADYaWQ5yymFTH7W1eqd5Y4OTx67F0fGlccs/DhXO7LeKmc5Aozv6frfbQ3bOcrbnNMc8/Ou2dohRnu3p7OoRY806AGjyOkYOcTTrud7ljHFKyHnBKnE6B6vamMnY4j2uNvDMT0ig0PyEBArNT0ig0PyEBArNT0igtHu1X0R6AlgMYN/03/9OVa8VkWEA7gfQD8AKABeqapO3rf49gYuMCR+vO1fZrfkjzc6i4IPG2trmFbZW41y5b4lch7gAeH3k7okO73T67X35x7a2wVlYedUNjnZKdHw/Z1LST6faWo2jPae29pa19FatPQaljuZUmGL393uv40OKjErAxx04nWfypx8BmKiqo5Bajvs0ETkGwA0AblbVEUilf3HmuyWEFJp2za8p9vxP657+UgATAfwuHZ8H4My8ZEgIyQsZvUkQkW7pFXq3AlgIYC2ABlXdM5N6I4DB+UmREJIPMjK/qu5W1dFI3Us1HtH3QEV+8hKRGSJSKSKVjXE/ExFCck6HrvaragOAPwE4BkBfEdlzwXAIgE3GmLmqWq6q5cXF2aRKCMkl7ZpfRAaISN/0z70AnAygBsAiAGen/2w6AOfObEJIZ0NUnToJABE5CqkLet2Q+mexQFV/JCLD8Wmp72UAF6jqR962RpeKPmvUBDZeta857r4bozc735ns0eBMzujrlHkaFtraTlvKPSWO5pWbYvYMNJngaM4kkj5GqW+7s4zaF75pa2dMsjVjLhMA4NZ10fFttziDnDIxnNKnu0Scl+TDRtzrTWhN1JoB6OsqzshPaLfOr6pVAD5TnVXVdUh9/ieEdEF4hx8hgULzExIoND8hgULzExIoND8hgdJuqS+nOxOpA/BW+tcS2B3WkoR57A3z2JuulscXVNUrLH5Coubfa8cilapaXpCdMw/mwTz4tp+QUKH5CQmUQpp/bgH33RrmsTfMY2/+ZvMo2Gd+Qkhh4dt+QgKlIOYXkdNE5K8iskZEvMWn8p3HehGpFpGVIlKZ4H7vEpGtIvJqq1g/EVkoIqv
T3532pHnNY5aIvJM+JitF5PQE8hgqIotEpEZEVonIlel4osfEySPRYyIiPUXkRRF5JZ3HD9PxYSKyLH08HhCRHlntSFUT/UJqavBaAMMB9ADwCoDDks4jnct6ACUF2O8JSE0cfbVV7GcAZqZ/ngnghgLlMQvAfyZ8PEoBjE3/3BvAGwAOS/qYOHkkekwACIDi9M/dASxDqoHOAgDnpeO3A/hWNvspxJl/PIA1qrpOU62+7wfgNGb+20NVFwPY1iY8Fam+CUBCDVGNPBJHVWtVdUX65x1INYsZjISPiZNHomiKvDfNLYT5BwPY0Or3Qjb/VADPishLIjKjQDns4UBVrQVSL0IAAwuYyxUiUpX+WJD3jx+tEZEypPpHLEMBj0mbPICEj0kSTXMLYf6oLiOFKjkcr6pjAUwBcLmInFCgPDoTtwE4GKk1GmoBJLZUiYgUA3gIwLdVdXtS+80gj8SPiWbRNDdTCmH+jQBar89jNv/MN6q6Kf19K4BHUNjORFtEpBQA0t+3FiIJVd2SfuG1ALgDCR0TEemOlOHuVdU9ja0SPyZReRTqmKT33eGmuZlSCPMvBzAifeWyB4DzADyedBIisr+I9N7zM4BTALzqj8orjyPVCBUoYEPUPWZLMw0JHBMREQB3AqhR1ZtaSYkeEyuPpI9JYk1zk7qC2eZq5ulIXUldC+C/C5TDcKQqDa8AWJVkHgDuQ+rt48dIvRO6GEB/ABUAVqe/9ytQHr8FUA2gCinzlSaQx98h9Ra2CsDK9NfpSR8TJ49EjwmAo5BqiluF1D+aH7R6zb4IYA2ABwHsm81+eIcfIYHCO/wICRSan5BAofkJCRSan5BAofkJCRSan5BAofkJCRSan5BA+X+oAC6reFaYfAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "img, _ = cifar2[0]\n", "\n", "plt.imshow(img.permute(1, 2, 0))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "img_batch = img.view(-1).unsqueeze(0)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.3700, 0.6300]], grad_fn=)" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out = model(img_batch)\n", "out" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1])" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "_, index = torch.max(out, dim=1)\n", "\n", "index" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 1.]])" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "softmax = nn.Softmax(dim=1)\n", "\n", "log_softmax = nn.LogSoftmax(dim=1)\n", "\n", "x = torch.tensor([[0.0, 104.0]])\n", "\n", "softmax(x)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ -inf, 0.0000]])" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.log(softmax(x))" ] }, { "cell_type": "code", "execution_count": 228, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-104., 0.]])" ] }, "execution_count": 228, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_softmax(x)" ] }, { "cell_type": "code", "execution_count": 229, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 1.]])" ] }, "execution_count": 229, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.exp(log_softmax(x))" ] }, { "cell_type": "code", "execution_count": 230, 
"metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " nn.Linear(3072, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 2),\n", " nn.LogSoftmax(dim=1))" ] }, { "cell_type": "code", "execution_count": 231, "metadata": {}, "outputs": [], "source": [ "loss = nn.NLLLoss()" ] }, { "cell_type": "code", "execution_count": 232, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.6509, grad_fn=)" ] }, "execution_count": 232, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img, label = cifar2[0]\n", "\n", "out = model(img.view(-1).unsqueeze(0))\n", "\n", "loss(out, torch.tensor([label]))" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 2.686693\n", "Epoch: 1, Loss: 2.695894\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Epoch: %d, Loss: %f\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/optim/sgd.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0md_p\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 107\u001b[0;31m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'lr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_p\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "model = nn.Sequential(\n", " nn.Linear(3072, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 2),\n", " nn.LogSoftmax(dim=1))\n", "\n", "learning_rate = 1e-4\n", "\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.NLLLoss()\n", "\n", "nepochs = 100\n", "\n", "for epoch in range(nepochs):\n", " for img, label in cifar2:\n", " out = model(img.view(-1).unsqueeze(0))\n", " loss = loss_fn(out, torch.tensor([label]))\n", " \n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 97, "metadata": {}, "outputs": [], "source": [ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n", "\n", "model = nn.Sequential(\n", " nn.Linear(3072, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2),\n", " nn.LogSoftmax(dim=1))\n", "\n", "learning_rate = 1e-2\n", "\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.NLLLoss()\n", "\n", "nepochs = 100\n", "\n", "for epoch in range(nepochs):\n", " for imgs, labels in train_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " loss = loss_fn(outputs, labels)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 0.483681\n", "Epoch: 1, Loss: 0.444471\n", "Epoch: 2, Loss: 0.433452\n", "Epoch: 3, Loss: 0.300507\n", "Epoch: 4, Loss: 0.441095\n", "Epoch: 5, Loss: 0.517073\n", "Epoch: 6, Loss: 0.305444\n", "Epoch: 7, Loss: 0.556281\n", "Epoch: 8, Loss: 0.333576\n", "Epoch: 9, Loss: 0.339125\n", "Epoch: 10, Loss: 0.661290\n", "Epoch: 11, Loss: 0.156344\n", "Epoch: 12, Loss: 0.228916\n", "Epoch: 13, Loss: 0.582214\n", "Epoch: 14, Loss: 0.319739\n", "Epoch: 15, Loss: 0.348913\n", "Epoch: 16, Loss: 0.298435\n", "Epoch: 17, Loss: 0.269637\n", "Epoch: 18, Loss: 0.264462\n", "Epoch: 19, Loss: 0.394460\n", "Epoch: 20, Loss: 0.183625\n", "Epoch: 21, Loss: 0.252047\n", "Epoch: 22, Loss: 0.164425\n", "Epoch: 23, Loss: 0.356711\n", "Epoch: 24, Loss: 0.356409\n", "Epoch: 25, Loss: 0.261684\n", "Epoch: 26, Loss: 0.195436\n", "Epoch: 27, Loss: 0.357038\n", "Epoch: 28, Loss: 0.107395\n", "Epoch: 29, Loss: 0.270044\n", "Epoch: 30, Loss: 0.307356\n", "Epoch: 31, Loss: 0.109826\n", "Epoch: 32, Loss: 0.110416\n", "Epoch: 
33, Loss: 0.060309\n", "Epoch: 34, Loss: 0.121865\n", "Epoch: 35, Loss: 0.080135\n", "Epoch: 36, Loss: 0.187582\n", "Epoch: 37, Loss: 0.215502\n", "Epoch: 38, Loss: 0.099308\n", "Epoch: 39, Loss: 0.177017\n", "Epoch: 40, Loss: 0.274428\n", "Epoch: 41, Loss: 0.148291\n", "Epoch: 42, Loss: 0.165810\n", "Epoch: 43, Loss: 0.419964\n", "Epoch: 44, Loss: 0.171430\n", "Epoch: 45, Loss: 0.205333\n", "Epoch: 46, Loss: 0.049800\n", "Epoch: 47, Loss: 0.063027\n", "Epoch: 48, Loss: 0.096860\n", "Epoch: 49, Loss: 0.116223\n", "Epoch: 50, Loss: 0.028307\n", "Epoch: 51, Loss: 0.109180\n", "Epoch: 52, Loss: 0.029037\n", "Epoch: 53, Loss: 0.103832\n", "Epoch: 54, Loss: 0.065226\n", "Epoch: 55, Loss: 0.119214\n", "Epoch: 56, Loss: 0.086423\n", "Epoch: 57, Loss: 0.032513\n", "Epoch: 58, Loss: 0.016831\n", "Epoch: 59, Loss: 0.013779\n", "Epoch: 60, Loss: 0.025617\n", "Epoch: 61, Loss: 0.019802\n", "Epoch: 62, Loss: 0.061896\n", "Epoch: 63, Loss: 0.074194\n", "Epoch: 64, Loss: 0.036079\n", "Epoch: 65, Loss: 0.051228\n", "Epoch: 66, Loss: 0.091050\n", "Epoch: 67, Loss: 0.076752\n", "Epoch: 68, Loss: 0.035503\n", "Epoch: 69, Loss: 0.019295\n", "Epoch: 70, Loss: 0.070132\n", "Epoch: 71, Loss: 0.071184\n", "Epoch: 72, Loss: 0.027912\n", "Epoch: 73, Loss: 0.089483\n", "Epoch: 74, Loss: 0.023227\n", "Epoch: 75, Loss: 0.038895\n", "Epoch: 76, Loss: 0.010719\n", "Epoch: 77, Loss: 0.028945\n", "Epoch: 78, Loss: 0.026111\n", "Epoch: 79, Loss: 0.041060\n", "Epoch: 80, Loss: 0.019934\n", "Epoch: 81, Loss: 0.046343\n", "Epoch: 82, Loss: 0.035450\n", "Epoch: 83, Loss: 0.015880\n", "Epoch: 84, Loss: 0.024630\n", "Epoch: 85, Loss: 0.025230\n", "Epoch: 86, Loss: 0.030091\n", "Epoch: 87, Loss: 0.019569\n", "Epoch: 88, Loss: 0.028038\n", "Epoch: 89, Loss: 0.024674\n", "Epoch: 90, Loss: 0.026119\n", "Epoch: 91, Loss: 0.032457\n", "Epoch: 92, Loss: 0.006373\n", "Epoch: 93, Loss: 0.011275\n", "Epoch: 94, Loss: 0.008130\n", "Epoch: 95, Loss: 0.010406\n", "Epoch: 96, Loss: 0.009017\n", "Epoch: 97, Loss: 
0.022544\n", "Epoch: 98, Loss: 0.028746\n", "Epoch: 99, Loss: 0.008170\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n", "\n", "model = nn.Sequential(\n", " nn.Linear(3072, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 2),\n", " nn.LogSoftmax(dim=1))\n", "\n", "learning_rate = 1e-2\n", "\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.NLLLoss()\n", "\n", "nepochs = 100\n", "\n", "for epoch in range(nepochs):\n", " for imgs, labels in train_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " loss = loss_fn(outputs, labels)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "with torch.no_grad():\n", " for imgs, labels in train_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " _, predicted = torch.max(outputs, dim=1)\n", " total += labels.shape[0]\n", " correct += int((predicted == labels).sum())\n", " \n", "print(\"Accuracy: %f\" % (correct / total))" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.794000\n" ] } ], "source": [ "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "with torch.no_grad():\n", " for imgs, labels in val_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " _, predicted = torch.max(outputs, dim=1)\n", " total += labels.shape[0]\n", " correct += int((predicted == labels).sum())\n", " \n", "print(\"Accuracy: %f\" % (correct / 
total))" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " nn.Linear(3072, 1024),\n", " nn.Tanh(),\n", " nn.Linear(1024, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2),\n", " nn.LogSoftmax(dim=1))" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " nn.Linear(3072, 1024),\n", " nn.Tanh(),\n", " nn.Linear(1024, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2))\n", "\n", "loss_fn = nn.CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 0.830766\n", "Epoch: 1, Loss: 0.537586\n", "Epoch: 2, Loss: 0.520497\n", "Epoch: 3, Loss: 0.551923\n", "Epoch: 4, Loss: 0.613409\n", "Epoch: 5, Loss: 0.367955\n", "Epoch: 6, Loss: 0.461690\n", "Epoch: 7, Loss: 0.440201\n", "Epoch: 8, Loss: 0.330919\n", "Epoch: 9, Loss: 0.672060\n", "Epoch: 10, Loss: 0.214384\n", "Epoch: 11, Loss: 0.289531\n", "Epoch: 12, Loss: 0.242116\n", "Epoch: 13, Loss: 0.528731\n", "Epoch: 14, Loss: 0.319517\n", "Epoch: 15, Loss: 0.344509\n", "Epoch: 16, Loss: 0.336134\n", "Epoch: 17, Loss: 0.530772\n", "Epoch: 18, Loss: 0.260637\n", "Epoch: 19, Loss: 0.502093\n", "Epoch: 20, Loss: 0.118269\n", "Epoch: 21, Loss: 0.411113\n", "Epoch: 22, Loss: 0.205757\n", "Epoch: 23, Loss: 0.332014\n", "Epoch: 24, Loss: 0.416243\n", "Epoch: 25, Loss: 0.400543\n", "Epoch: 26, Loss: 0.096858\n", "Epoch: 27, Loss: 0.447385\n", "Epoch: 28, Loss: 0.257410\n", "Epoch: 29, Loss: 0.229853\n", "Epoch: 30, Loss: 0.234711\n", "Epoch: 31, Loss: 0.112913\n", "Epoch: 32, Loss: 0.182169\n", "Epoch: 33, Loss: 0.256886\n", "Epoch: 34, Loss: 0.417403\n", "Epoch: 35, Loss: 0.250923\n", "Epoch: 36, Loss: 0.050652\n", "Epoch: 37, Loss: 0.269337\n", "Epoch: 38, Loss: 0.281109\n", "Epoch: 39, Loss: 
0.143530\n", "Epoch: 40, Loss: 0.517404\n", "Epoch: 41, Loss: 0.059098\n", "Epoch: 42, Loss: 0.236813\n", "Epoch: 43, Loss: 0.089123\n", "Epoch: 44, Loss: 0.037688\n", "Epoch: 45, Loss: 0.081337\n", "Epoch: 46, Loss: 0.025698\n", "Epoch: 47, Loss: 0.189287\n", "Epoch: 48, Loss: 0.029990\n", "Epoch: 49, Loss: 0.104520\n", "Epoch: 50, Loss: 0.006282\n", "Epoch: 51, Loss: 0.009865\n", "Epoch: 52, Loss: 0.010996\n", "Epoch: 53, Loss: 0.032748\n", "Epoch: 54, Loss: 0.011533\n", "Epoch: 55, Loss: 0.031718\n", "Epoch: 56, Loss: 0.051569\n", "Epoch: 57, Loss: 0.000723\n", "Epoch: 58, Loss: 0.025385\n", "Epoch: 59, Loss: 0.002961\n", "Epoch: 60, Loss: 0.016891\n", "Epoch: 61, Loss: 0.007197\n", "Epoch: 62, Loss: 0.005152\n", "Epoch: 63, Loss: 0.006943\n", "Epoch: 64, Loss: 0.029282\n", "Epoch: 65, Loss: 0.016945\n", "Epoch: 66, Loss: 0.001008\n", "Epoch: 67, Loss: 0.003609\n", "Epoch: 68, Loss: 0.041969\n", "Epoch: 69, Loss: 0.011577\n", "Epoch: 70, Loss: 0.002608\n", "Epoch: 71, Loss: 0.002957\n", "Epoch: 72, Loss: 0.004216\n", "Epoch: 73, Loss: 0.001568\n", "Epoch: 74, Loss: 0.004731\n", "Epoch: 75, Loss: 0.002923\n", "Epoch: 76, Loss: 0.010029\n", "Epoch: 77, Loss: 0.007263\n", "Epoch: 78, Loss: 0.009012\n", "Epoch: 79, Loss: 0.002666\n", "Epoch: 80, Loss: 0.000942\n", "Epoch: 81, Loss: 0.004340\n", "Epoch: 82, Loss: 0.004340\n", "Epoch: 83, Loss: 0.004355\n", "Epoch: 84, Loss: 0.000616\n", "Epoch: 85, Loss: 0.002305\n", "Epoch: 86, Loss: 0.000936\n", "Epoch: 87, Loss: 0.002565\n", "Epoch: 88, Loss: 0.008841\n", "Epoch: 89, Loss: 0.004678\n", "Epoch: 90, Loss: 0.019521\n", "Epoch: 91, Loss: 0.002626\n", "Epoch: 92, Loss: 0.356031\n", "Epoch: 93, Loss: 0.004904\n", "Epoch: 94, Loss: 0.000408\n", "Epoch: 95, Loss: 0.002940\n", "Epoch: 96, Loss: 0.000598\n", "Epoch: 97, Loss: 0.000866\n", "Epoch: 98, Loss: 0.000400\n", "Epoch: 99, Loss: 0.001557\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "train_loader = 
torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n", "\n", "model = nn.Sequential(\n", " nn.Linear(3072, 1024),\n", " nn.Tanh(),\n", " nn.Linear(1024, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2))\n", "\n", "learning_rate = 1e-2\n", "\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.CrossEntropyLoss()\n", "\n", "nepochs = 100\n", "\n", "for epoch in range(nepochs):\n", " for imgs, labels in train_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " loss = loss_fn(outputs, labels)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.998100\n" ] } ], "source": [ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "with torch.no_grad():\n", " for imgs, labels in train_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " _, predicted = torch.max(outputs, dim=1)\n", " total += labels.shape[0]\n", " correct += int((predicted == labels).sum())\n", " \n", "print(\"Accuracy: %f\" % (correct / total))" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.802000\n" ] } ], "source": [ "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "with torch.no_grad():\n", " for imgs, labels in val_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " _, predicted = torch.max(outputs, dim=1)\n", " total += labels.shape[0]\n", " correct += int((predicted == labels).sum())\n", " \n", "print(\"Accuracy: %f\" % (correct / total))" ] }, { "cell_type": "code", 
"execution_count": 108, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3737474" ] }, "execution_count": 108, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum([p.numel() for p in model.parameters()])" ] }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3737474" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum([p.numel() for p in model.parameters() if p.requires_grad == True])" ] }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1574402" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "source": [ "first_model = nn.Sequential(\n", " nn.Linear(3072, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 2),\n", " nn.LogSoftmax(dim=1))\n", "\n", "sum([p.numel() for p in first_model.parameters()])" ] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1573376" ] }, "execution_count": 112, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum([p.numel() for p in nn.Linear(3072, 512).parameters()])" ] }, { "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3146752" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum([p.numel() for p in nn.Linear(3072, 1024).parameters()])" ] }, { "cell_type": "code", "execution_count": 114, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1024, 3072]), torch.Size([1024]))" ] }, "execution_count": 114, "metadata": {}, "output_type": "execute_result" } ], "source": [ "linear = nn.Linear(3072, 1024)\n", "\n", "linear.weight.shape, linear.bias.shape" ] }, { "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [], "source": [ "conv = nn.Conv2d(3, 16, kernel_size=3)" ] }, { "cell_type": "code", "execution_count": 116, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([16, 3, 3, 3])" ] }, "execution_count": 116, "metadata": {}, "output_type": "execute_result" } ], "source": [ "conv.weight.shape" ] }, { "cell_type": "code", "execution_count": 117, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([16])" ] }, "execution_count": 117, "metadata": {}, "output_type": "execute_result" } ], "source": [ "conv.bias.shape" ] }, { "cell_type": "code", "execution_count": 118, "metadata": {}, "outputs": [], "source": [ "img, _ = cifar2[0]\n", "\n", "output = conv(img.unsqueeze(0))" ] }, { "cell_type": "code", "execution_count": 122, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 3, 32, 32]), torch.Size([1, 16, 30, 30]))" ] }, "execution_count": 122, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img.unsqueeze(0).shape, output.shape" ] }, { "cell_type": "code", "execution_count": 130, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHL5JREFUeJztnXtsnOWVxp8zN4/vduzEdu4hF5IAJSBD6YIgvVDRblVa7RaVP7qsVDXVblEXqdJuhVYq/6yEdlvaSrutFLaoQSq9aGkXpFIWmtKmLW2JCTQJJORO4thxHN9vY8/l7B+e7DphzplJnMwY3ucnRbHnzPe9Z97ve+ab8fOd84qqghASHpFKJ0AIqQwUPyGBQvETEigUPyGBQvETEigUPyGBQvETEigUPyGBQvETEiix+WwsIvcA+DaAKID/VNVHvedH62s11tJcOJgTf6ysE8vZMY062yWdnQJIxOx4BPadkVm1X8tMxp9yTdvvx5Kw86mJp+19FrmJcyqdMGPRiD25Cv+YZTPOtcXJKRJzDmgRcs6YNclpe0znpYxPVbljevnmss4ciDMJRfSAbOF4ZnAQ2YmJIhvPctniF5EogP8AcDeAbgC7ReRZVX3THKylGe3//OXC+0v5H0ISw3Y8OmW/1nSDPcHVm4bdMVc02fFk1BbbyEy1GXu733jzy5Ppt7dNdkyYsfd19JixnPNmBAD7epeasbpqWzAZ78QGMDJUa8bUOHkBoKZpyoxJkdN6YtCev1s2HjdjiYj9xvr7/evdMWtbJ+18hu18InF7zNyUL83oWOGr2unHvuVud8H4JT/zndwK4IiqHlPVGQA/AnDvPPZHCCkj8xH/MgCn5vzenX+MEPIuYD7iL/QB7B2fsUVkm4h0iUhXdsz+2EoIKS/zEX83gBVzfl8O4B1fPFV1u6p2qmpntN7+DkgIKS/zEf9uAOtFZI2IJAB8FsCzVyYtQsjV5rL/2q+qGRF5EMD/YNbqe0JV3/C2qa9O4SM3FjYDdh681h2v4VXbbplucvKM2X8eHhu1/xILAKcdKyadtT3EiZGkGWtv9x2Gto4+M7Y4OW7GqiIZM/Zy72p3zOqqGTM2nbFf53h3g7tfabL329Y2asZSM3EzNjzof3qsO2TblvsabVcj4liaMuNfI1PH681YleO6TS+391l1xpfm9GLDKYiU3pxnXj6/qj4H4Ln57IMQUhl4hx8hgULxExIoFD8hgULxExIoFD8hgULxExIo87L6LpWIqOlH64xTewvgXKftw6681vbGT73ZbsZq99p+PACMt9r3FuSqbD9VonasL9rojpnssP36pTUjZuzw6GIz5t2TAADptB2f6asxY5rwS2/vWHvUjL3eZ5eBTE3a8y4T/inbcNyulDtXX2fGEuecKkP/dhBMXGtXPuZa7O2i5+z7GdJ1vl8vNcZ5cgmXc175CQkUip+QQKH4CQkUip+QQKH4CQkUip+QQCmr1TcyWY2fv/a+grHmdrvEEwDa68fMWO+oXVoambm85p4AUNdtb5scsLc9c6cd+9CGQ+6Yh0dsy+4X+643YxGvs2+tbUUBwPSEXQYba7G3vWnlKTMGAFVOU8zJI7bl6dmoG6/3x3yrsc3er2Nprv65XX587n1+GXGk2WngOWbbyepcenM1vo0qI4ZN6DRGvRhe+QkJFIqfkECh+AkJFIqfkECh+AkJFIqfkEApq9UnaTG7kg6p3wnWw+v2WrPB7pY77tgwAKAROy6OpdL0hj2tL+Wuc8eURttyEqczq7eGnXgLQgLYusm2HwembZtrJuufPjuP2h2ZI976gc7r/Kv2Pe6YtUtta3L3+Boz9vwHbzNjkxt8qzQ+5VilJ+xzKOes/6lp37JLni1sW4pdFPoOeOUnJFAofkICheInJFAofkICheInJFAofkICpaxWnyYU0ysKW1nVRxzfA8DIZLMZE6cA6o479pmxNdecc8fs2WivAPr8oc3
2hmfs15IY8N9vdcipAltrV4+tWTJgxlbWDbljVkdte9Gz+nLw7SjPmlSnyalazSkBvDjozDuAPbtse7Gmx1k0c5VzEmX8Y5Y9Y3f4jDmbZhrt11lscdDplsL56iUoel7iF5ETAMYAZAFkVLVzPvsjhJSPK3Hl/6Cq+pdQQsiCg9/5CQmU+YpfAbwgIq+KyLZCTxCRbSLSJSJd2bGJeQ5HCLlSzPdj/+2q2iMiSwC8KCIHVXXX3Ceo6nYA2wGgas1y/yZzQkjZmNeVX1V78v+fBfAzALdeiaQIIVefyxa/iNSKSP35nwF8FMD+K5UYIeTqMp+P/W0AfiaztaQxAE+p6vPeBpIWxPsKlz9avuV5cvW2J1rVY5f0vvTbG+xYkUanDettf/zvb/yNGVt1i21+7Oj9C3fMfW+uNGMRx0/uTtjdcI/1tbpj6ml7v9nmtBlrW2qXSwPAijZ7/vqS9WYsFrO7/qYy9rEGgOU77XsWkr3jZmzoRvs+kkjav0YObbDjObvaF3DudZBm+3UAgA4bOy69ee/li19VjwG48XK3J4RUFlp9hAQKxU9IoFD8hAQKxU9IoFD8hARKWUt6kQOiU4W9iPQS21IC4C5AmKmxLZNsg20bxQf9lz/61iIz9u+n7jZjVUvs0ttPrrNLjAGg7WZ7QdJf7rE7/6betq2zSKaI/+N1BZ62F7fs67VLngGgocW+nXt6yrbsbttwwoxF4N8kOvXLV83Y5MdvMWOZKmeOkv78Zeqd8y/p5Jtz9tvvl7jXnCl83Y74DuGFzy39qYSQ9xIUPyGBQvETEigUPyGBQvETEigUPyGBUl6rL+JYH8XafDgVUJHltrXmLQiZLjJmZMq2uaqMhRIBIDdYZ8b+0GAvFgkAqYxzSKrsykev460kbLsTAHJZ+xoQj9vb1iR9XynuVOeNTtqv869bd5uxfzt2jztmzaYOMzZwnW0vZp3qO7cyD0Cmzn6dsTFnbkftOag97Y8ZmSl88kaKOOYXPLf0pxJC3ktQ/IQECsVPSKBQ/IQECsVPSKBQ/IQESlmtPskAyf7C1ptG/MaM2Rrb5kqLvW1kxNlvwm8aqk4TxVTCnrpIve23dJ+xG0UCQNVRe6HOaKOdb7bJtvrUWTgUAHKNjlW12J6DVU3+AqA3NZ0yY92t9jz8ZU3KjD20x7byAGBlh9PA8y67sWrOsX0Hu/3qxcSgbfu27bbnNuJ4zdVv29WdADC9tLCd7O3zHc8t+ZmEkPcUFD8hgULxExIoFD8hgULxExIoFD8hgULxExIoRX1+EXkCwCcAnFXV6/OPLQLwYwCrAZwAcJ+q+qYvgFwVMH5NYd8z2Wd7pQCQHLDj6Vo7lqm1vXGJ+l1ZZciu5axZafuwy5vsBSyPvLbCHdNbaFGcLryRMedQFitdTtnXgKkJ+x6B2BK/VPjEVIsZe3/DMTP25Ki9sGjbbv/ejMk2+76OrUsPm7Epp6b3t1n/3JxstOeop8q+b6PxkH08o6lad0zrnNdI6St1lnLl/z6Ai4uovwpgp6quB7Az/zsh5F1EUfGr6i4Agxc9fC+AHfmfdwD41BXOixBylbnc7/xtqtoLAPn/l1hPFJFtItIlIl3ZcXt9dEJIebnqf/BT1e2q2qmqndE6u70VIaS8XK74+0SkAwDy/5+9cikRQsrB5Yr/WQAP5H9+AMAzVyYdQki5KMXq+yGArQBaRaQbwNcAPArgJyLyeQAnAXymlMEkAyQGC7/fxP0KRreDamzKjqXt9Sshad8WySVtWyketW2u7mG7BDQ2UboVczHqdDDWmNPduMjrTIza14BJx0Y9OmhbcgAwNmHbXLuy68xYXb1d0lsb819LutaO/9fuTjMWb5w2Y+3N/sn5dxt2mbG+6xrN2DMbbjBjx/oa3DHjdYXnKPO6b4XOpaj4VfV+I/ThkkchhCw4eIcfIYFC8RMSKBQ/IYFC8RMSKBQ/IYFS3oU6FYhMF7ZiUi1+6Zk6TXi9LT2bK+N0rQX
gLg46ftDuPhuxG+kikvWtqkytPaZX1Zfstd/HY7ZzBgDIOs196w7bE1/1a7+rbf24U1HpOFLd9zk7Xepfr9QJV5+0X0uqzd6we6DaHfPp2M1mrLnKXkT2muYBM5Zruric5kLW1Bbe9qkqx/e+CF75CQkUip+QQKH4CQkUip+QQKH4CQkUip+QQCmr1adRYLrF8HiKvA1Fp2ybK+s16fSaXk76g0YNWxIAYpN2rMpxaTK+a4RUh20/RsecJqYNtkWYutauWAOASMyZv2474cSwb1s2HrZtJ43Zcx/pqbFj9jqcAICMvSlit9g9ZnXQbpgZHfRlcuRouz2ms21kzYQZu2Fpjzvm3uFlBR+fyvoL3l4wfsnPJIS8p6D4CQkUip+QQKH4CQkUip+QQKH4CQkUip+QQCmrzx+fBNpeKRybcbquAkDU8XerxmyPO9ln17OmG3xPtK/Tbhlc3W+PWdNv++bnbvAXffTwymBzVU73Xqc0GQCg9tx33HjGjE1t9ufv8A122XPtSWcexM63ts8vw+75ZNqM/aPTZffxo7ebsWVrR90xD/UtNmORU3b76IxzWFqq7HsAAODYUOFFUDO50q/nvPITEigUPyGBQvETEigUPyGBQvETEigUPyGBUspCnU8A+ASAs6p6ff6xRwB8AUB//mkPq+pzxfYVncyg6bVzBWNa7azECSAyPG7GNOFYTmf6zVCs3bZoAKBhsR3PRW17bGiDbWOl2p3WvgAiKef92HFDc1WODzjiW3JRp7T51FRhSwkAks7ilgDQuHLEjOVWOF2Vh+y63FSTf57ouH1Kf+fQnWbsjmXHzdjftPzeHfOPrWvN2AuLN5uxTQ22jdpQpOVyS23hrsCxSOkLdZZy5f8+gHsKPP5NVd2S/1dU+ISQhUVR8avqLgB+E3FCyLuO+Xznf1BE9orIEyJi38pFCFmQXK74vwtgLYAtAHoBfMN6oohsE5EuEemaydirlxBCystliV9V+1Q1q6o5AI8DuNV57nZV7VTVzkTMabBGCCkrlyV+EemY8+unAey/MukQQspFKVbfDwFsBdAqIt0AvgZgq4hswewamScAfLGUwTQWQaa1rmAsF/ffh6TeXk1ybJUdqzvdaMYm23zbaGyFndNMo12SlVlSpMWsQ84rbnRWoYxOOLGUXzHpdeGNdNtzNNXuW4ijy227SiJOSZszCcMb/QrFZJ99Sk+k7HPh91hjxt4aWeKO+cElh8zY6jr7b+Wba+wOvd0zi9wx71p8uODjB2O+/TqXouJX1fsLPPy9kkcghCxIeIcfIYFC8RMSKBQ/IYFC8RMSKBQ/IYFC8RMSKGXt3putimB4XeFVX7NJf9uZetv7nW6xvd9s3L4HYHSdP2Zkw5gZq4nbpbmp/U1mLFdkxmNrndJlx+LOHit8/wQAVA36Pr9z+wCm2uxBnaa/AIDEYXuFX3U6Cmc77Ln1VmsGgFzC3u/ijYXLyQFgYL9dvn16yi9d2VFdeMVcAMgssrsJ/3HJKjM2Nu4v53z3+oMFH5/R0rtD88pPSKBQ/IQECsVPSKBQ/IQECsVPSKBQ/IQESpmtPmBkfeGYFskkm7C7kuaabGtoatguSZ1p8TvpLqu3F0scmrCtmOaDtt00cINvVaX67f1Kjb1IpS6xLaWJhD+56thjniUnmSKLq6Zs20mcst3YkJ3v4j/73Wm9zsmLqu1OUoMr7WOd7vWb0CzZbcemWu3zb7zDLtvNFDk3Xzi0qeDjo6lfudvNhVd+QgKF4ickUCh+QgKF4ickUCh+QgKF4ickUMpq9WkUSNcb1pH4XVk1aVs8saRj9bXZLzHR7C+GGHFyqnqxwY4N25Zcxm4gCwCQafv9OH7atrFyTjFXpsnOB8BsD2aD5Bl7/rLV/jHzDulMk308E8P2HETS/pgJe21QHNqz0ozVrxs2Y6tu7nPHPDhqL9RZd9LONzZh253Vm+zqTgAY7TeqON32zxfCKz8hgULxExIoFD8hgULxExI
oFD8hgULxExIopSzUuQLAkwDaAeQAbFfVb4vIIgA/BrAas4t13qeqQ0X3Zzg8kvHfhyLjdjyddSrE0nYsk/abHXYfbDNjG1/qN2OT6+yGj8lef3HL5IAdcxtmerHT/mGOZGw7asZ2NBEf922l2h7bzhtZax9P6xwBgGjKr+obXWfHNW6/zuEe+4U2VvuWsGdbTuTs15m0+4miqdauQASAsRGj+rN0p6+kK38GwFdUdROA2wB8SUQ2A/gqgJ2quh7AzvzvhJB3CUXFr6q9qron//MYgAMAlgG4F8CO/NN2APjU1UqSEHLluaTv/CKyGsBNAP4EoE1Ve4HZNwgA/iLmhJAFRcniF5E6AE8DeEhVRy9hu20i0iUiXdlxu1sKIaS8lCR+EYljVvg/UNWf5h/uE5GOfLwDwNlC26rqdlXtVNXOaF3tlciZEHIFKCp+EREA3wNwQFUfmxN6FsAD+Z8fAPDMlU+PEHK1KKWq73YAnwOwT0Rezz/2MIBHAfxERD4P4CSAz1ydFAkhV4Oi4lfV38F2Dz98SaOJ3SnWK28EgMSo4+VPOmWnSacst2bGHTP+srPQZNweM11r3z/Q9ordZRcAqne9acZkWbsZ69tq/721rscv6U3X2R8ARzY4C192ubtF42G7LHVsle2r151yvPpokYU66+3y7oZ9difdqdvsv0cNTvqLZiJmz1FqmX28Uyvs7eqd+wMAoPqtwivbRlIs6SWEFIHiJyRQKH5CAoXiJyRQKH5CAoXiJyRQytq9F1BopLC9ka3yt0y12LZIzinV9DrBTp6qd8fcfN9xM3Zka6sZm+mz86k96Zf0to9vMGOT7c6io022xVPzsn83dvc99oKRkrVtt0WvDbr7PflJe44yN4+Zsakp+7j0fcS3Sq9ZYZdaH0vbVmlbo2319fU0uWPeeP0JM1YXnzZjGbXPzd3HV7ljJoxTwdnlO+CVn5BAofgJCRSKn5BAofgJCRSKn5BAofgJCZTyWn0qkJnC7ze5Kr8rq1UNOBu0Q9OODdiyxm82nIjYFWLTg3al17XXdZuxa25z2vMC+MXm68zY4iW2tZb9w2IzlquzLUIAmFhhz/2SV+ztclW+bdn2EXsejh2zOyM3n7Xz2bTBtl8B4OFlz5mxX3VsNGO7BtabscnWIh2XY7b9uChhW4h3NbxlxnrG/RVduwcKe+PqN6S+AF75CQkUip+QQKH4CQkUip+QQKH4CQkUip+QQCmv1RdVoLGwLSLDvp0ik3bVWq7atoa+87Hvm7HtPXe6Yx7Yads/Xu/Pow227daa9BcuaW6xm14Oj9aYsWp7MwxcV2S9hCa78qz+bXtuM01+KeaJM7bFuPhl+9SbbLePdf9UnTvmvc9/2Q46bnJtu31c1rU4K2oC2HNquRmLRm2r+cFbf23G7mw74o754wHDCoz5lvlceOUnJFAofkICheInJFAofkICheInJFAofkICheInJFCK+vwisgLAkwDaMeuUblfVb4vIIwC+AOB8u9SHVdWupwQgEUW8urDPPzPp1yJG6u2yybYWuxPsU/3vN2N/7lrrjhmtdhb5HHK65f7SLvfd27zZHXNime3T1p2y36vTjv2drfIXb4w4C02mG+z7LwY2F7k3o9fe70yDndPkLZNmrHvA76Tb8qp9Ho2vsLebStkdg/eOFWktrfZrqXfOzS8fvc+MHX7dSRZArsEoN8+VvlBnKTf5ZAB8RVX3iEg9gFdF5MV87Juq+vWSRyOELBhKWaK7F0Bv/ucxETkAYNnVTowQcnW5pO/8IrIawE0A/pR/6EER2SsiT4hIs7HNNhHpEpGu7Kh/ayshpHyULH4RqQPwNICHVHUUwHcBrAWwBbOfDL5RaDtV3a6qnaraGW0oco85IaRslCR+EYljVvg/UNWfAoCq9qlqVlVzAB4HcOvVS5MQcqUpKn4REQDfA3BAVR+b83jHnKd9GsD+K58eIeRqUcpf+28H8DkA+0Tk9fxjDwO4X0S2YLZ
37gkAXyy2I1VBetoYsohDkZuxLZy+7oJ/bgAADL1ml9fWDPuDZp2mt/Ex28aq6c+asVSLP+V1J+3349Z9dunt6bvsZCVbxOo7mTRjGrVfS/Pdve5+x0+3mLHsWTvf+AG7dHn62il3zOGN9nHJNtndmKND9nGJv+1bfTnnkFa322Me7VppxqpGighiZargw3IJJb2l/LX/dygsTdfTJ4QsbHiHHyGBQvETEigUPyGBQvETEigUPyGBUtbuvZFIDvUNha2a0axt7wBA3OkEG3M6+9Y4lWXqdFYFgKlWe7+xaXvbiXbbltQPjLhjjvbZd0GOr7Sr6OLL7fa9kVfsijUAaDhmxxKjdjXl8b3t7n6rHbsqaruWqD9p21WpftuWBPwKxtRyu+Vyrtq25OQt/9yEsxjs0G/sOaop7NYBAMbW2BYrAGxYVHiR2R7Hmr0YXvkJCRSKn5BAofgJCRSKn5BAofgJCRSKn5BAKavVl0tHMdbdUDiY9C2KdKtTkdVtW2DpOsf6afGtPnUsnLOd9na5ejvXxBvG689TPW3nO7PJrmj70JrDZuzXkXXumKkuO6d0rX2KRGwXEAAwtcp+QiphH+/0QdvOa3/F8QgBpOucBp5rbMsul7TtRa33K+Wyi+zjvaTLzme60b72xsf86/KJc4sKPj6TKV3SvPITEigUPyGBQvETEigUPyGBQvETEigUPyGBQvETEihl9fklC8RHCr/f6Lj/PpRucjrirrJLNaejjkc76i80WbPMLpPNOQsiZo7YJbQ1Z/yurNUDdr5yyO4i+/ztW8zYlpuOumMe/4DtU/c02Z2Rs0n/Pol4ve3J37thnxnbs8xepPJE21J3zIbD9nmUKFwFCwCYdu75aN446I557pS9eKhe5uVV/XVrkZ4sXOKul7BQJ6/8hAQKxU9IoFD8hAQKxU9IoFD8hAQKxU9IoIiqb9dc0cFE+gG8PeehVgDnypZAcZiPz0LLB1h4OVU6n1Wqaq9OO4eyiv8dg4t0qapTGV9emI/PQssHWHg5LbR8PPixn5BAofgJCZRKi397hce/GObjs9DyARZeTgstH5OKfucnhFSOSl/5CSEVoiLiF5F7ROQtETkiIl+tRA4X5XNCRPaJyOsi0lWhHJ4QkbMisn/OY4tE5EUROZz/3y6xK08+j4jI6fw8vS4iHy9jPitE5CUROSAib4jIP+Qfr8gcOflUbI4ulbJ/7BeRKIBDAO4G0A1gN4D7VfXNsiZyYU4nAHSqasX8WRG5E8A4gCdV9fr8Y/8KYFBVH82/STar6j9VMJ9HAIyr6tfLkcNF+XQA6FDVPSJSD+BVAJ8C8LeowBw5+dyHCs3RpVKJK/+tAI6o6jFVnQHwIwD3ViCPBYWq7gJwceH4vQB25H/egdmTq5L5VAxV7VXVPfmfxwAcALAMFZojJ593DZUQ/zIAp+b83o3KT5oCeEFEXhWRbRXOZS5tqtoLzJ5sAJZUOB8AeFBE9ua/FpTta8hcRGQ1gJsA/AkLYI4uygdYAHNUCpUQf6FWI5W2HG5X1ZsBfAzAl/Ifeck7+S6AtQC2AOgF8I1yJyAidQCeBvCQqo6We/wS8qn4HJVKJcTfDWBun6blAHoqkMf/oao9+f/PAvgZZr+aLAT68t8tz3/HPFvJZFS1T1WzqpoD8DjKPE8iEses0H6gqj/NP1yxOSqUT6Xn6FKohPh3A1gvImtEJAHgswCerUAeAAARqc3/wQYiUgvgowD2+1uVjWcBPJD/+QEAz1Qwl/PiOs+nUcZ5EhEB8D0AB1T1sTmhisyRlU8l5+hSqchNPnn741sAogCeUNV/KXsS/5/LNZi92gOzDU2fqkQ+IvJDAFsxWxXWB+BrAP4bwE8ArARwEsBnVLUsf4Qz8tmK2Y+zCuAEgC+e/75dhnzuAPBbAPsAnO9y+jBmv2eXfY6cfO5HheboUuEdfoQECu/wIyRQKH5CAoXiJyRQKH5CAoXiJyRQKH5CAoXiJyRQKH5CAuV
/AdN4UxrDfFcUAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 131, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 16, 30, 30])" ] }, "execution_count": 131, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.shape" ] }, { "cell_type": "code", "execution_count": 169, "metadata": {}, "outputs": [], "source": [ "conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)" ] }, { "cell_type": "code", "execution_count": 170, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 1, 32, 32])" ] }, "execution_count": 170, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = conv(img.unsqueeze(0))\n", "\n", "output.shape" ] }, { "cell_type": "code", "execution_count": 171, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " conv.bias.zero_()" ] }, { "cell_type": "code", "execution_count": 172, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " conv.weight.fill_(1.0 / 9.0)" ] }, { "cell_type": "code", "execution_count": 173, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAF79JREFUeJztnV+MXdV1xn8L4z/YGBvb2Bjj1BD5IVFUSDRCkaiiNGkjGlUikZooeYh4QHFUgdRI6QOiUkOlPiRVkygPVSqnoJAqhND8UVAVtUEoFcoLiUMJfwIFgt1gPLGNwdgQAni8+nCv1WFy1zd3zsyca7K/nzSaO2fdfc4++5xv7r37u2vtyEyMMe1xzqQ7YIyZDBa/MY1i8RvTKBa/MY1i8RvTKBa/MY1i8RvTKBa/MY1i8RvTKOcupnFEXAN8GVgB/Etmfk49f926dblx48aRMfVNw5mZmZHbzzmn/t+1YsWKMta13bnnjh4utT91XqdPn+4U65OIWNJY1/11Hcfq3lH76xrrej2rmGpTjdWJEyd45ZVX6oGcRWfxR8QK4J+APwUOAj+NiLsz8xdVm40bN3LDDTeMjP32t78tj/XSSy+N3L5mzZqyzQUXXFDG1q1bV8Y2bNhQxjZt2rTg/b3++utlrDovgN/85jdlbKlR//BWrlxZxtQ/vVWrVo3cvnr16k7HOnXqVBk7ceJEGavG+NVXXy3bVP8wAF577bUypq6ZilX3/iuvvFK2qV6I7rjjjrLNXBbztv8q4KnMfDozXwPuBK5dxP6MMT2yGPHvAJ6Z9ffB4TZjzJuAxYh/1OeK3/lAFBF7ImJfROx7+eWXF3E4Y8xSshjxHwR2zvr7UuDQ3Cdl5t7MnMrMKfXZ2BjTL4sR/0+B3RFxWUSsAj4G3L003TLGLDedZ/sz81RE3Aj8JwOr77bMfFS1iYhyllLN9Faz+uedd96C24C2a9TsfDW7ff7553c6VjUWoMdDWUDV8dSMvhpH9W6tmtGHelZfzfar8VBukJqdr2b71RiqfnS189R9VTkBysWoxn4hFvGifP7M/AHwg8XswxgzGfwNP2MaxeI3plEsfmMaxeI3plEsfmMaZVGz/V2obBmV0aWsqIquSTPKmjt27NjI7ZdddlnZpkoGAp0kor4NqWyvytJTdqSyttTYq3aVLaoSdNQ90DWx5+jRoyO3q6QZlRSm7g+VLKRi1f2o7tMumYBz8Su/MY1i8RvTKBa/MY1i8RvTKBa/MY3S62z/zMwMJ0+eHBlTSSLVTHXXBB2VCKJm2asZW1XOqjpf0P1Xs9GKtWvXjtyuEnuqNqATpNQMfOVkdF0VWs1iq5n0qh+qTdd7p8uMPtTOlHKsKmdkIePrV35jGsXiN6ZRLH5jGsXiN6ZRLH5jGsXiN6ZRerf6XnjhhZExZUVVdk2XxBLQySpdln568cUXyzbKelGJLKofqv/VOKrxUCjbSyUYVbaXus7KKlMWbBdLTPVDWZhdk3e6jJWyFdW9My5+5TemUSx+YxrF4jemUSx+YxrF4jemUSx+YxplUVZfRBwATgIzwKnMnFLPn5mZKW0ZZWtUVp/KOFM24IYNG8qYqrlXLTWlbCN1XspGUxZhl2xGZZWpGnhda+5V1qIajyNHjpSxZ555pozt37+/jFXHU/eHyqhU/Vd2nhrHCtXHKrYQC3ApfP4/zsznlmA/xpge8dt+YxplseJP4IcR8bOI2LMUHTLG9MNi3/ZfnZmHImIrcE9EPJ6Z981+wvCfwh7Qyz0bY/plUa/8mXlo+PsI8D3gqhHP2ZuZU5k5pdZmN8b0S2fxR8S6iFh/5jHwAeCRpeqYMWZ5Wczb/m3A94bWwrnAHZn5H6pBZpa2nbJJKrrYJ6Az5lQxy8rSU8VHV65cWcaULaPsJpXFVll9aqyUValsRZVdWJ23WqLs0KFDZezRRx/t1K76qHnhhReWbRTqmikbUMWq+1FdF2UDjkvnPWT
m08AVi+6BMWYi2OozplEsfmMaxeI3plEsfmMaxeI3plF6LeCZmaXloSyg6stBKitOWWVq3bRqXUCobS9lh6nMQ2U5dikkCrU9tBQFHxdC1X9l6arsQmVvdrExlT2rYupaq36o+7G6j1XR1Sq2kOvsV35jGsXiN6ZRLH5jGsXiN6ZRLH5jGqX32f5qaSI1m1vN9qvZUDVbrmb71Wx01e75558v26gaBiqmZpxVanS1T1W38IILLuh0LDXzXTkSanzVNduyZUsZ27ZtWxm79NJLR25X56X6eOzYsU7tutQFVPfAUuBXfmMaxeI3plEsfmMaxeI3plEsfmMaxeI3plF6t/qq5Adl9VX14LokuIBOwFD2YZWAcfLkybJNl4QOgIsuuqiMKWuuqk2n2qhafIrKtoV6jJUtqpK7lFW5devWMrZr166R29W9c/z48TKm+qiSuJStW/VF7a+yKp3YY4yZF4vfmEax+I1pFIvfmEax+I1pFIvfmEaZ1+qLiNuAPweOZOY7hts2Ad8CdgEHgI9m5gvz7Utl9SkLRVlpFV0XBe1SO08tyaWsQ7XkkrKGNm7cWMbWr18/cnvX8VBWpbLtqnp8amktVcNP1VaszlnF1HXpYrHNF1OZpJXl26XuoroX5zLOK//XgGvmbLsJuDczdwP3Dv82xryJmFf8mXkfMPdf/LXA7cPHtwMfWuJ+GWOWma6f+bdl5jTA8Hf9FStjzFnJsn+9NyL2AHug++dOY8zS0/WV/3BEbAcY/j5SPTEz92bmVGZOLXdZImPM+HQV/93AdcPH1wHfX5ruGGP6Yhyr75vAe4EtEXEQ+CzwOeCuiLge+BXwkXEOlpll0UdlsVUZYiorTmX1dS38WfVdHUvZcps2bSpjVXYewNq1a8tYlaGnzlllVE5PT5exgwcPlrEqM05lzKlxVJmHKkuzstiUddglaxJgx44dZUxlEVaFP1WbqpDoQj5azyv+zPx4EXr/2Ecxxpx1+Bt+xjSKxW9Mo1j8xjSKxW9Mo1j8xjRKrwU8I6LMZFOZVJUtozLflLWlrCG13lpl9ansPHVeav25rgU8q+MpW1TZbypz7+jRo2WsytBTGZrqvBRdrrUaD4WyHJV1q9pVfVH36YEDB0ZuX+qsPmPM7yEWvzGNYvEb0ygWvzGNYvEb0ygWvzGN0rvVV2VuKbtsIfbFGbpkCUJt56mYykZTGXhq/bmua+tVNRO6Zswp+0rZZeq8K1RGmrJ11ThW46EKk7788stlTJ2zKripallU96q6F9X9PS5+5TemUSx+YxrF4jemUSx+YxrF4jemUXqd7VeoGfguSRjKPehKNRut6sGpWWo1A6xmjrvUIFQzx2q2/5JLLiljmzdvLmMqWahCzcCrc+7iEqjls6ol5eaLvfBCvWKdWo6uumbKoakS0BbiAviV35hGsfiNaRSL35hGsfiNaRSL35hGsfiNaZRxluu6Dfhz4EhmvmO47Rbgk8CZIm43Z+YP5tvXOeecUyZ8qESWyq5RySOqdp5aBklZOZWlpOrtKatP2TJdagmqfSq7VI2HskzXr19fxqo+qmumEmrUdVG2XZeahl2TzJRdrWoXVpaeGt/KJlYW8VzGeeX/GnDNiO1fyswrhz/zCt8Yc3Yxr/gz8z6gLuFqjHlTspjP/DdGxEMRcVtE1EuXGmPOSrqK/yvAW4ErgWngC9UTI2JPROyLiH3qc5sxpl86iT8zD2fmTGaeBr4KXCWeuzczpzJzaiFrhxtjlpdO4o+I7bP+/DDwyNJ0xxjTF+NYfd8E3gtsiYiDwGeB90bElUACB4BPjXOwVatW8Za3vGVkTGWPVZaHsgdVpp2y0Z577rkyVmWdqXc06qOOWgpLZbipbK8qe0z1Q1lUyjpStldlpSnL66WXXipjKqvvyJEjZawaD5Vlp1BWpeqjilX399atWxfcRt0bc5lX/Jn58RGbbx37CMaYsxJ/w8+YRrH4jWkUi9+YRrH4jWkUi9+YRum1gOd
5553HFVdcMTK2ZcuWsl1lr6iMOWXJHDt2rIw98cQTZezJJ58cuf3EiRNlG2WxqcKZKitRxapsuq5ZccpyVHZZZfWpsVcWrMo8VJZpdd6qH2o8lIXctZBrdd7Kyq7OS2V8zsWv/MY0isVvTKNY/MY0isVvTKNY/MY0isVvTKP0avWtWbOG3bt3j4xV26EuqKgymLoWYdy/f38Zq+whZTWpPirLTvVxw4YNZayyD5VV9uKLL5YxleWo2lU2oCrSqcZKjYdaF7CyvlR2oVpzT42HsuZUrBoTZdtVsYVkK/qV35hGsfiNaRSL35hGsfiNaRSL35hG6XW2/9xzz2Xz5s0jY9u2bSvbVTXmVO05hZr5VjO9v/71r0duV7P9Xeu6qaW8VBLUhReOXkJB7U/Nlh8+fLiMqYSmKllFjf3GjRvLmLrWKlYdTyVVqVl2lSClrqdqV7kman+Vm7UQTfiV35hGsfiNaRSL35hGsfiNaRSL35hGsfiNaZRxluvaCXwduBg4DezNzC9HxCbgW8AuBkt2fTQza5/s//e34E5W9eBUfTlleSg7T9V2qywxtcyUsthUkovap0ouUTZghUpyOXr0aBlT/a8STFSdu4svvriMVctTgV4urUL1o7JL50NZc2qsKttOaaWLjuYyziv/KeAzmfk24N3ADRHxduAm4N7M3A3cO/zbGPMmYV7xZ+Z0Zj4wfHwSeAzYAVwL3D582u3Ah5ark8aYpWdBn/kjYhfwTuB+YFtmTsPgHwRQLylqjDnrGFv8EXE+8B3g05lZf6/zd9vtiYh9EbFPfdY2xvTLWOKPiJUMhP+NzPzucPPhiNg+jG8HRi6Snpl7M3MqM6e6TqQYY5aeecUfg2nFW4HHMvOLs0J3A9cNH18HfH/pu2eMWS7Gyeq7GvgE8HBEPDjcdjPwOeCuiLge+BXwkfl2lJmlBVfZeVDbRmpZJWWtqI8fKqOrqp2nlpnqWitOWX2qTltlfyobSo2j6r/K0KtQtpzK6quyQUHXx6v6qOonqneo6lhqjFV2ZGUHd8kIVW3mMq/4M/PHQGUqvn/sIxljzir8DT9jGsXiN6ZRLH5jGsXiN6ZRLH5jGqXXAp5QWy8qC6+yQlSbrhl/69atK2OXXHLJyO2rVq0q2yirTNleqiiostgqq1LZkSqmrKO1a9eWserclGWnsvp27txZxpQ1V1m+avkvtdSbOme1T1XAs7Ju1XXuYrPOxa/8xjSKxW9Mo1j8xjSKxW9Mo1j8xjSKxW9Mo/Rq9amsvi721YoVK8o2yqJSxQ9VuyoTTGWjqUKRVZYg6CKd1dpuUGdHKstRjWPX7LfqvNU5q/Xz1P2hshKrduqclfWpsi2V5avOrbJFlT2o+jEufuU3plEsfmMaxeI3plEsfmMaxeI3plF6n+2vZl9VDb8qpmrZnThRVxdXMTXLXrVTs8NqBli5BMqRUDPH1ay+mjnukqAD2smoYippRjkSBw8eLGNdEnHUdVGOj1rOTY3Vpk2byljlgKglyipNLGQZL7/yG9MoFr8xjWLxG9MoFr8xjWLxG9MoFr8xjTKv1RcRO4GvAxcDp4G9mfnliLgF+CRwdPjUmzPzB2pfp0+fLmvrKfutsoDUcldPPfVUGdu/f38Ze/bZZ8vY0aNHR25XiSXKelH1ApXdpGrFVRabGitllSk7r4tVqeonKptV1TRUY1VZbMqWU0u2VfUkQY9jVf8R4PLLLx+5XdmDS8E4Pv8p4DOZ+UBErAd+FhH3DGNfysx/XL7uGWOWi3HW6psGpoePT0bEY8CO5e6YMWZ5WdBn/ojYBbwTuH+46caIeCgibouIOrnbGHPWMbb4I+J84DvApzPzBPAV4K3AlQzeGXyhaLcnIvZFxD5VhMIY0y9jiT8iVjIQ/jcy87sAmXk4M2cy8zTwVeCqUW0zc29mTmXmlKriYozpl3nFH4Np21uBxzLzi7O2b5/1tA8Djyx994wxy8U4s/1XA58AHo6IB4fbbgY+HhF
XAgkcAD41345Onz5dWnqHDx8u21WZVNPT02Wbxx9/vIypduqjSZVFqDLmVLaispSUzaOy8KradMoOUxaVquGnbMBqrKrls0CPfdc6g1VM9UNdT9VO9VFZnNU7YrW/anzVPTWXcWb7fwyMMm2lp2+MObvxN/yMaRSL35hGsfiNaRSL35hGsfiNaZReC3ieOnWK48ePl7GKKqPr0KFDC24DulCkysKrrC1lNVVZjKCLSKovRKlYtYyT6qPKVFN2kyokWlmLyvpUFpuyI1U/qvNW95vKMFX3lUKNf3Ufqz5W46uWNfudfYz9TGPM7xUWvzGNYvEb0ygWvzGNYvEb0ygWvzGN0qvVNzMzUxaSrCwqqAtkKitEFcdU7dT6f5W9otqoTDVlsSnLUVmElR25devWso3qf5djQV3cU2UrqnXwVB+VDVgVO1Xn1fWaqftKredYnXeX4q+qf3PxK78xjWLxG9MoFr8xjWLxG9MoFr8xjWLxG9MovVp9mdmp8GBlsSmrSWWjqbXuVMZflZHW1bJTGW5q/T+V/Vad9+bNm8s2auy7FvesYqrNli1bypjKjlRjVVl66rooO1JdM3XvqGy7yv5WY19pwlafMWZeLH5jGsXiN6ZRLH5jGsXiN6ZR5p3tj4g1wH3A6uHzv52Zn42Iy4A7gU3AA8AnMrPOvjhzwGIGU83OVzOzamZTzeirmm9q5riacVZJGwuZfZ2NShJRs8rV7LZarksluailwdQ4VuOv6g+uXr26jKnZeZWIUyWMqWumksIuuuiiMqbGSjkqlROgHI7K8Vnq2f5Xgfdl5hUMluO+JiLeDXwe+FJm7gZeAK4f+6jGmIkzr/hzwJl/nyuHPwm8D/j2cPvtwIeWpYfGmGVhrM/8EbFiuELvEeAe4JfA8cw88970ILBjebpojFkOxhJ/Zs5k5pXApcBVwNtGPW1U24jYExH7ImKfWt7YGNMvC5rtz8zjwH8B7wY2RsSZ2btLgZErD2Tm3sycyswpNZFijOmXecUfERdFxMbh4/OAPwEeA34E/MXwadcB31+uThpjlp5xEnu2A7dHxAoG/yzuysx/j4hfAHdGxN8D/w3cOs4BleVRUVkhyg5TiRSqD8pyrGxK9Y5GJbIolO2lbMxqTFQbdSxlAyqqBBiVlNTVnlUJXl2sVnUPqGutkn662MFd7gFlic5lXvFn5kPAO0dsf5rB539jzJsQf8PPmEax+I1pFIvfmEax+I1pFIvfmEaJLtZb54NFHAX+d/jnFuC53g5e4368EffjjbzZ+vEHmVmnHs6iV/G/4cAR+zJzaiIHdz/cD/fDb/uNaRWL35hGmaT4907w2LNxP96I+/FGfm/7MbHP/MaYyeK3/cY0ykTEHxHXRMT/RMRTEXHTJPow7MeBiHg4Ih6MiH09Hve2iDgSEY/M2rYpIu6JiCeHvy+cUD9uiYhnh2PyYER8sId+7IyIH0XEYxHxaET81XB7r2Mi+tHrmETEmoj4SUT8fNiPvxtuvywi7h+Ox7ciolvK5Rkys9cfYAWDMmCXA6uAnwNv77sfw74cALZM4LjvAd4FPDJr2z8ANw0f3wR8fkL9uAX4657HYzvwruHj9cATwNv7HhPRj17HBAjg/OHjlcD9DAro3AV8bLj9n4G/XMxxJvHKfxXwVGY+nYNS33cC106gHxMjM+8Dnp+z+VoGhVChp4KoRT96JzOnM/OB4eOTDIrF7KDnMRH96JUcsOxFcych/h3AM7P+nmTxzwR+GBE/i4g9E+rDGbZl5jQMbkJg6wT7cmNEPDT8WLDsHz9mExG7GNSPuJ8JjsmcfkDPY9JH0dxJiH9UqZFJWQ5XZ+a7gD8DboiI90yoH2cTXwHeymCNhmngC30dOCLOB74DfDozT/R13DH60fuY5CKK5o7LJMR/ENg56++y+Odyk5mHhr+PAN9jspWJDkfEdoDh7yOT6ERmHh7eeKeBr9LTmETESgaC+0Zmfne4ufcxGdWPSY3J8NgLLpo
7LpMQ/0+B3cOZy1XAx4C7++5ERKyLiPVnHgMfAB7RrZaVuxkUQoUJFkQ9I7YhH6aHMYlB4blbgccy84uzQr2OSdWPvsekt6K5fc1gzpnN/CCDmdRfAn8zoT5czsBp+DnwaJ/9AL7J4O3j6wzeCV0PbAbuBZ4c/t40oX78K/Aw8BAD8W3voR9/xOAt7EPAg8OfD/Y9JqIfvY4J8IcMiuI+xOAfzd/Oumd/AjwF/BuwejHH8Tf8jGkUf8PPmEax+I1pFIvfmEax+I1pFIvfmEax+I1pFIvfmEax+I1plP8DMQi2q65RHfEAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "output = conv(img.unsqueeze(0))\n", "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 174, "metadata": {}, "outputs": [], "source": [ "conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)\n", "\n", "with torch.no_grad():\n", " conv.weight[:] = torch.tensor([[-1.0, 0.0, 1.0],\n", " [-1.0, 0.0, 1.0],\n", " [-1.0, 0.0, 1.0]])\n", " conv.bias.zero_()" ] }, { "cell_type": "code", "execution_count": 175, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGLdJREFUeJztnV2snNV1hp+FweAfgjnYxsfYgHFIBImKYx2hSFQoTdqIRpFIpCZKLiIuUBxVQSpSeoGo1IDUi6RqEuWiSuUUFFKlIWl+FFShNgilQrkhMZS/YFoTY4zrf8DY4d/26sWMpcNh1nvGc3y+Mez3kY7OnL1mf9+ePd86M7PfedeOzMQY0x5njHsAxpjx4OQ3plGc/MY0ipPfmEZx8hvTKE5+YxrFyW9Mozj5jWkUJ78xjXLmXDpHxHXAt4EFwD9n5tfU/ZcsWZITExMDY2+88UbZb/HixQPbzzij/t91/PjxMhYRZUwds+Lo0aNlTD0uNY4zzxztqam+sanOpWKjfgO0mkd1rmPHjo0UU891NQ41v+pc8/GN2GpO1LnOOuusge0HDx7kyJEj9SRPY+Tkj4gFwD8CfwbsAn4bEfdk5pNVn4mJCW6++eaBseeee64818aNGwe2V/8UAF555ZUytnDhwjJ2zjnnlLHqQnr++efLPs8++2wZUxfg8uXLy5ii+kekHnN1IQG8+eabZUxdnIsWLRrYrv65vvzyy2Xs8OHDI8WWLl06sL16EQL4wx/+UMbUfIz6j+3ss88e2K5eVFasWDGw/fbbby/7zGQub/uvBp7OzO2Z+QZwN3D9HI5njOmQuST/RcD0l+td/TZjzDuAuST/oPc4b3sfGBGbImJLRGxRb+uMMd0yl+TfBayd9vcaYPfMO2Xm5sycysypJUuWzOF0xphTyVyS/7fA5RGxLiIWAp8D7jk1wzLGzDcjr/Zn5tGIuAn4T3pS352Z+TvVJyLKlWW1slmtiquV+QMHDpQx9fFj7dq1ZWzZsmUD219//fWyj2JUSUnNVbXirOZKrfarmHrc1WMbVZ5V6o1SilavXj2w/bLLLiv7vPbaa2VMKQsLFiwoY+pxV/Oo5rdSMU5Gqp6Tzp+Z9wL3zuUYxpjx4G/4GdMoTn5jGsXJb0yjOPmNaRQnvzGNMqfV/pPljDPOKM04SmI7cuTIwHZlslBuupdeeqmMKalk5cqVA9svvvjiss8LL7xQxpTp58UXXyxj73nPe8rYueeeO7C9koZAG4yU7KXGWEmO55133kjjUOzZs6eMVfLsqlWryj5Kztu/f38ZU1KfQl2rFaM4Xd9235M+qzHmXYGT35hGcfIb0yhOfmMaxclvTKN0utqv+MAHPlDG9u7dO7BdmXdU2SplIKnOBXV5J2WaUaiVb3
XMquwT1Kv6yk6t5uPQoUMjxSrVQRmFlLKgVtJHUX3UONS5Rq0lqKgedzWHUD/PJ6M4+JXfmEZx8hvTKE5+YxrFyW9Mozj5jWkUJ78xjdKp1Hfs2LFSHnrve99b9lM7oVRUJhzQBpKnn366jO3atWtg+yWXXFL2UTKUkpvOP//8MjbKlmKqJqCSr1S9Q1VjrtpxSElRyiikzrVmzZoyVu1so8xdKqaeM/W8KPNOdcz5rnbtV35jGsXJb0yjOPmNaRQnvzGN4uQ3plGc/MY0ypykvojYARwBjgFHM3NK3f/o0aOl1KeknEoKUe42JQ9OTk6WMUU1RuUEVPKVGn9Vow301lWV81DJm0p+U/2UHFm5KpUT86mnnipjaozr168vY1WtvmqeQEt2ymmnXH1qi7XqOlDO1EqeVZLuTE6Fzv8nmXnwFBzHGNMhfttvTKPMNfkT+GVEPBQRm07FgIwx3TDXt/3XZObuiFgJ3BcRT2XmA9Pv0P+nsAl0vXljTLfM6ZU/M3f3f+8Hfg5cPeA+mzNzKjOn1CKWMaZbRk7+iFgSEeeeuA18HHjiVA3MGDO/zOVt/4XAz/uS2pnAv2bmf6gOx48fL2Uq5X6r5CYl51VbfAFMTEyUMeUGrOQh5fRSH3UOHqxFEhVTElAl9SgZSo1fOcuUJFY9N2q7q23btpWxyp0HcO2115axav6Vy27RokVlTElpyg2o5r+aYyWzVvnSidSXmduBq0btb4wZL5b6jGkUJ78xjeLkN6ZRnPzGNIqT35hG6XyvvkryUFJI5ehSfZR7TO2Dd+GFF5ax6nxKXlFSn5KGlOtMSZWVTKVko/nYf67qpx6Xil1xxRVlTM1x9djUNaBk4sOHD5cxVexUybPVczbqfpPD4ld+YxrFyW9Mozj5jWkUJ78xjeLkN6ZROl3tz8yylplavaxWPZcuXVr2eeGFF8qYqqunVoFfe+21ge1qJV2NUW3zpWq+KUPTKPXgRlUd1HNWjaOaQ9BzPzVVl4es6vQBbN26dWC7qglYbcsG2pik5lgZpJSRqGIU5WkmfuU3plGc/MY0ipPfmEZx8hvTKE5+YxrFyW9Mo3Qu9VW1x5SkVMlGauskJa2o+m0qVkl6SipTMuCaNWvKmKqrp0xLldSjtkN79dVXy5jaGkzNVWU+UhKmMlVt3LixjCmJsDLiqK3SlLFHzYd6ztT5qlqIqq5l9Xxa6jPGzIqT35hGcfIb0yhOfmMaxclvTKM4+Y1plFmlvoi4E/gksD8zP9hvmwB+BFwK7AA+m5m1Va5PZpauLrX1UyUPqT7nn39+GVPSnJKvqlpxShpS0ouSqFRMORYr2U7JRkrqUzXrlOxVjV9JXuvWrStjF1xwQRlT46+cdqM6MZVUqaRn5SKs5lE5IKt6gep5nskwr/zfA66b0XYLcH9mXg7c3//bGPMOYtbkz8wHgJkvNdcDd/Vv3wV86hSPyxgzz4z6mf/CzNwD0P9db21rjDktmfcFv4jYFBFbImKL+gxjjOmWUZN/X0RMAvR/l7WNMnNzZk5l5pRaxDLGdMuoyX8PcEP/9g3AL07NcIwxXTGM1PdD4CPA8ojYBXwV+Brw44i4EdgJfGaYkylXnypKWW25pLaZUu8ylINQyYBVocXly5eXfUZxc4GW2EbZqkk9rt27d5cxJSsqF1v1fK5YsaLso7YhUxLsoUOHylgl26lim8pduGzZsjKm5Dz1kbe6HpUc+cwzzwxsV+7Ntx1/tjtk5ueL0MeGPosx5rTD3/AzplGc/MY0ipPfmEZx8hvTKE5+Yxql0wKeEVFKLEoSq+RBJYcplESoHGKVq0+50ZQMqMavJDYl9VWSmJL6Dh48WMbUfKxcWX+ru5L6lHy1evXqMqZkNPXYKtQ1oMZ43nnnlTHlBnzuuefKWCVVKmdqtWegch3OxK/8xjSKk9+YRnHyG9MoTn5jGsXJb0yjOPmNaZTOpb5KRhnFqaaKFSopRzm61DGr2OLFi8s+ygWm3G
hKYlNzNco8KgfkqIUuq6KUO3fuLPtceeWVZUzJio8++mgZq/Y1VI9ZyWWqkKgq4KnOV12rqvirkj6Hxa/8xjSKk9+YRnHyG9MoTn5jGsXJb0yjdLrar1A1zqqVTWUGUltJqVVUtbpdGXvUOFS9QFVvTdX3U4pEdUz1uFTNOjUO9bj37t07sH3btm1ln6uvvrqMVXM/2ziqlXulpiilSF2nyiCljlmZhVSNxEo9UM/X2+479D2NMe8qnPzGNIqT35hGcfIb0yhOfmMaxclvTKMMs13XncAngf2Z+cF+223AF4ETrolbM/PeIY5VbnmlzBSV5KEMNcrg8vLLL5cxNY5KLlMyzp49e8qYMvaMKvVVEpaSw5RBR0mEaou1ShJTEqyaDyWZqsdWXTvKGKOuKyUTq+da9Vu1atXAdnUNV/N7MsauYV75vwdcN6D9W5m5of8za+IbY04vZk3+zHwAqEvJGmPekczlM/9NEfFYRNwZEXWNYWPMacmoyf8dYD2wAdgDfKO6Y0RsiogtEbFFfaXSGNMtIyV/Zu7LzGOZeRz4LlB+KTszN2fmVGZOVYt9xpjuGSn5I2Jy2p+fBp44NcMxxnTFMFLfD4GPAMsjYhfwVeAjEbEBSGAH8KVhTrZw4UIuvvjigTElbVUxVYtPOeZefPHFkWIrVqwY2P7888+XfdS2W9XWWqClKOUsq2RAVUNOyUNKVlyyZEkZq+aqagfttFMyoHpHWY1fXR9q2zB1LnVMJUdW51MSciVXn4zUN2vyZ+bnBzTfMfQZjDGnJf6GnzGN4uQ3plGc/MY0ipPfmEZx8hvTKJ0W8Fy8eDEbNmwYGNu9e3fZr5JQlKyhHGfVFk4A27dvL2OV+03JP0rOU04vJQ0p11kll6lxKIlK9VPzX83VmjVryj5qHpXspWTRI0eODGxXLkG1JZcqdjoxMVHG1FxVUraSZ6utwU5mGy+/8hvTKE5+YxrFyW9Mozj5jWkUJ78xjeLkN6ZROpX6zjnnHC6//PKBsZ07d5b9KilKOQGV5KGce08++WQZq2Seq666quyjUHJTJeUAnH9+XTipkpT2799f9lEuwVEKmkK9t97y5cvLPtWedaClMiUDVrKocgmquVKuRLWfoCpcWjk/lXRYuWOV03UmfuU3plGc/MY0ipPfmEZx8hvTKE5+Yxql09X+BQsWlCu6o2yhVW3FBNokos6lDEZVPbhly5aVfZSJSNX+UwYYpQRUxh61Iq7q9CmDlFqBr54ztYI9OTlZxpTBSNX+q86n+qjHrGoyjjpXlfqkru9K8bGxxxgzK05+YxrFyW9Mozj5jWkUJ78xjeLkN6ZRhtmuay3wfWAVcBzYnJnfjogJ4EfApfS27PpsZtaOGXo166q6dUrKqWQqVQNPmU5GNYlUdelUvb1RpT4lvylDUyUPqflQEpXqp6SoKqaMPevXry9jSmJTc1wZe5RBR9VIVEYnZdRS/So5Ul2nVb6onJjJMK/8R4GvZOYVwIeBL0fElcAtwP2ZeTlwf/9vY8w7hFmTPzP3ZObD/dtHgK3ARcD1wF39u90FfGq+BmmMOfWc1Gf+iLgU+BDwIHBhZu6B3j8IYOWpHpwxZv4YOvkjYinwU+DmzDx8Ev02RcSWiNiiPuMaY7plqOSPiLPoJf4PMvNn/eZ9ETHZj08CA8ufZObmzJzKzCm1GYIxpltmTf7oLQXfAWzNzG9OC90D3NC/fQPwi1M/PGPMfDGMq+8a4AvA4xHxSL/tVuBrwI8j4kZgJ/CZYU6oJI+TRdVFU+dRUs773ve+MlZJfWqLL/VRR0llynVWbUGlUFs/qeOpunqqlmD13Kg+q1atKmPPPPNMGXvppZfKWPW41TWwbt26MrZ3794ytmvXrjKm6iRW8qeSvw8fHvzJW0nVM5k1+TPz10AlBH9s6DMZY04r/A0/YxrFyW9Mozj5jWkUJ78xjeLkN6ZROi3gmZmlFKFkr8rdpOQw5W5S2ypdcsklZawq1KkKgq
qCipV0CLBv374ypqS5Si6rtjwDLUMpp51ynVWS2JIlS8o+qjCpGqOa/1GuN3XtqPErObWS5qB2cKrH1ZWrzxjzLsTJb0yjOPmNaRQnvzGN4uQ3plGc/MY0ymkj9SlJrJJllFyj3E3K8af2+KukqIULF5Z9lKy4cmVd/Eg5xFQxy4suumhge7UfHOhioUrqU067SppTBUGfffbZMqYes5LfqutqPvbjU8VJ1fVdFWTdv39giQygvoYt9RljZsXJb0yjOPmNaRQnvzGN4uQ3plE6X+1XK8sV1QqrOpaq76e2u1KrspU5Zu3atSONozIKga79p2KVGUSZgZShRtXwU+Oo5lGZgR566KEyprbCUrX/qvlQprBDhw6VMWXeUaqPqg1Zbff2yCOPDGyHevxqfmfiV35jGsXJb0yjOPmNaRQnvzGN4uQ3plGc/MY0yqxSX0SsBb4PrAKOA5sz89sRcRvwReBA/663Zua96ljHjx8vpRdltqnqlakaZ2oLJ1U7T8k1VR02VV9OGVmUSeT9739/GduxY0cZqww8an4nJyfLmDKKKONJ1U8ZjJRRSBm1RjH2qOtD1TtUMrGa48pwBbWcqqS+U1HDbxid/yjwlcx8OCLOBR6KiPv6sW9l5j8MfTZjzGnDMHv17QH29G8fiYitQP1vzBjzjuCkPvNHxKXAh4AH+003RcRjEXFnRNTbrxpjTjuGTv6IWAr8FLg5Mw8D3wHWAxvovTP4RtFvU0RsiYgt6vOeMaZbhkr+iDiLXuL/IDN/BpCZ+zLzWGYeB74LXD2ob2ZuzsypzJxSe7MbY7pl1uSP3pL0HcDWzPzmtPbpS8SfBp449cMzxswXw6z2XwN8AXg8Ik5oD7cCn4+IDUACO4AvzXagY8eOlY4pJZdV0ouS+pTsomrnLV68uIxVW2iprZiUC0yhnILKTbd9+/aB7UqiGsUVB9oBWVGND7TTbsWKFWVMyanVdTDqtaPqRlbuPNDyclXnUR2vesxKPn7bMWa7Q2b+Ghh0RKnpG2NOb/wNP2MaxclvTKM4+Y1pFCe/MY3i5DemUTot4KmkPlV4sHIqKalJbZOl5DwlyVTyipKGVOFM5fRSxT2VfFg57dRcqeMpCVbNVdWvkktBPy/q+VRjrJ4bJYnt3r27jKl+V1xxRRlTY6y2B1PXwOrVqwe2KylyJn7lN6ZRnPzGNIqT35hGcfIb0yhOfmMaxclvTKN0LvVVDjjlzKqKalZyB2j5TRWDVAUQqz3tDh48WPZRDjwl9SnpU7m9JiYmBraropSqmGUlQ6lzQS05jVJ8FLTspeSt6rpSsqIahyoWquRItddgJX+q+hfVdWqpzxgzK05+YxrFyW9Mozj5jWkUJ78xjeLkN6ZROpX6MrOU2ZSzrJIvlKyhimpW+5zN1q/aU00Vx3zllVdGio0q9V1wwQUD29UY1bkOHDhQxpScWjn+VGFSJZmqmHLMVRKhciQqGVDJxErWPZnCmieonkuFksxn4ld+YxrFyW9Mozj5jWkUJ78xjeLkN6ZRZl0ajIhzgAeAs/v3/0lmfjUi1gF3AxPAw8AXMrNeRqe34lytVKvV0GprpVFr8akVfWXAqKiMR6DNL0p1qGodgl5xrhQQtd2Vmiu1rZUyQVXPpzLoKBVj7969ZUwpAdX4K+UGYM2aNWVMMarCVK3QK2PPokWLBrYrVedt9x3iPq8DH83Mq+htx31dRHwY+Drwrcy8HHgRuHHosxpjxs6syZ89TrwcntX/SeCjwE/67XcBn5qXERpj5oWh3iNExIL+Dr37gfuA3wOHMvPEtzx2AbU53Rhz2jFU8mfmsczcAKwBrgYGFSgf+AEwIjZFxJaI2KI+ExljuuWkVvsz8xDwX8CHgWURcWKlYg0wcKeDzNycmVOZOaUqnRhjumXW5I+IFRGxrH97EfCnwFbgV8Bf9O92A/CL+RqkMebUM4wLYBK4KyIW0Ptn8ePM/PeIeBK4OyL+Dvhv4I5hTqjkoYpK2l
KSlzI4KEOQOma15ZWSypSEqWQZVTtPbQFWmW1GqXMH2jSjnsvKLFRJVKDlN/WY1Rir2oXKlLRy5cqTPh7Aq6++WsZGqRs5iuysTFozmTX5M/Mx4EMD2rfT+/xvjHkH4m/4GdMoTn5jGsXJb0yjOPmNaRQnvzGNEqNIbyOfLOIA8Gz/z+VAbcfqDo/jrXgcb+WdNo5LMrO2cE6j0+R/y4kjtmTm1FhO7nF4HB6H3/Yb0ypOfmMaZZzJv3mM556Ox/FWPI638q4dx9g+8xtjxovf9hvTKGNJ/oi4LiL+JyKejohbxjGG/jh2RMTjEfFIRGzp8Lx3RsT+iHhiWttERNwXEdv6v+vqjfM7jtsi4v/6c/JIRHyig3GsjYhfRcTWiPhdRPxVv73TORHj6HROIuKciPhNRDzaH8ft/fZ1EfFgfz5+FBF15dhhyMxOf4AF9MqAXQYsBB4Frux6HP2x7ACWj+G81wIbgSemtf09cEv/9i3A18c0jtuAv+54PiaBjf3b5wL/C1zZ9ZyIcXQ6J0AAS/u3zwIepFdA58fA5/rt/wT85VzOM45X/quBpzNze/ZKfd8NXD+GcYyNzHwAmGnYv55eIVToqCBqMY7Oycw9mflw//YResViLqLjORHj6JTsMe9Fc8eR/BcBz037e5zFPxP4ZUQ8FBGbxjSGE1yYmXugdxECdUWJ+eemiHis/7Fg3j9+TCciLqVXP+JBxjgnM8YBHc9JF0Vzx5H8g0rbjEtyuCYzNwJ/Dnw5Iq4d0zhOJ74DrKe3R8Me4BtdnTgilgI/BW7OzLFVex0wjs7nJOdQNHdYxpH8u4Dpm7SXxT/nm8zc3f+9H/g5461MtC8iJgH6v/ePYxCZua9/4R0HvktHcxIRZ9FLuB9k5s/6zZ3PyaBxjGtO+uc+6aK5wzKO5P8tcHl/5XIh8Dngnq4HERFLIuLcE7eBjwNP6F7zyj30CqHCGAuinki2Pp+mgzmJXqHDO4CtmfnNaaFO56QaR9dz0lnR3K5WMGesZn6C3krq74G/GdMYLqOnNDwK/K7LcQA/pPf28U1674RuBC4A7ge29X9PjGkc/wI8DjxGL/kmOxjHH9N7C/sY8Ej/5xNdz4kYR6dzAvwRvaK4j9H7R/O3067Z3wBPA/8GnD2X8/gbfsY0ir/hZ0yjOPmNaRQnvzGN4uQ3plGc/MY0ipPfmEZx8hvTKE5+Yxrl/wHo7ZuXD1PlgwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "output = conv(img.unsqueeze(0))\n", "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 176, "metadata": {}, "outputs": [], "source": [ "pool = nn.MaxPool2d(2)" ] }, { "cell_type": "code", "execution_count": 177, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 3, 16, 16])" ] }, "execution_count": 177, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = pool(img.unsqueeze(0))\n", "\n", "output.shape" ] }, { "cell_type": "code", "execution_count": 178, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "ellipsis is not a Module subclass", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTanh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMaxPool2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m ...)\n\u001b[0m", "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_get_item_by_idx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36madd_module\u001b[0;34m(self, name, module)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m raise TypeError(\"{} is not a Module subclass\".format(\n\u001b[0;32m--> 171\u001b[0;31m torch.typename(module)))\n\u001b[0m\u001b[1;32m 172\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_six\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstring_classes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 173\u001b[0m raise TypeError(\"module name should be a string. 
Got {}\".format(\n", "\u001b[0;31mTypeError\u001b[0m: ellipsis is not a Module subclass" ] } ], "source": [ "model = nn.Sequential(\n", " nn.Conv2d(3, 32, kernel_size=3, padding=1),\n", " nn.Tanh(),\n", " nn.MaxPool2d(2),\n", " nn.Conv2d(32, 16, kernel_size=3, padding=1),\n", " nn.Tanh(),\n", " nn.MaxPool2d(2),\n", " nn.Conv2d(16, 8, kernel_size=3, padding=1),\n", " nn.Tanh(),\n", " nn.MaxPool2d(2),\n", " ...)" ] }, { "cell_type": "code", "execution_count": 179, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " nn.Conv2d(3, 32, kernel_size=3, padding=1),\n", " nn.Tanh(),\n", " nn.MaxPool2d(2),\n", " nn.Conv2d(32, 16, kernel_size=3, padding=1),\n", " nn.Tanh(),\n", " nn.MaxPool2d(2),\n", " nn.Conv2d(16, 8, kernel_size=3, padding=1),\n", " nn.Tanh(),\n", " nn.MaxPool2d(2),\n", " # WARNING: something missing here\n", " nn.Linear(128, 32),\n", " nn.Tanh(),\n", " nn.Linear(32, 2))" ] }, { "cell_type": "code", "execution_count": 180, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "10874" ] }, "execution_count": 180, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum([p.numel() for p in model.parameters()])" ] }, { "cell_type": "code", "execution_count": 181, "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "size mismatch, m1: [32 x 4], m2: [128 x 32] at /Users/lantiga/Desktop/invariant-ai/pytorch/aten/src/TH/generic/THTensorMath.cpp:932", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 
"\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 475\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 477\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 478\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 
90\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 91\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 92\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 475\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 477\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 478\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 479\u001b[0m 
\u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m 1065\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maddmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1066\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1067\u001b[0;31m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1068\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1069\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mRuntimeError\u001b[0m: size mismatch, m1: [32 x 4], m2: [128 x 32] at /Users/lantiga/Desktop/invariant-ai/pytorch/aten/src/TH/generic/THTensorMath.cpp:932" ] } ], "source": [ "model(img.unsqueeze(0))" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "class Net(nn.Module):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)\n", " self.act1 = nn.Tanh()\n", " self.pool1 = nn.MaxPool2d(2)\n", " self.conv2 = nn.Conv2d(32, 16, kernel_size=3, padding=1)\n", " self.act2 = nn.Tanh()\n", " self.pool2 = nn.MaxPool2d(2)\n", " self.fc1 = nn.Linear(8 * 8 * 8, 32)\n", " self.act4 = nn.Tanh()\n", " self.fc2 = nn.Linear(32, 2)\n", "\n", " def forward(self, x):\n", " out = self.pool1(self.act1(self.conv1(x)))\n", " out = self.pool2(self.act2(self.conv2(out)))\n", " out = out.view(-1, 8 * 8 * 8)\n", " out = self.act4(self.fc1(out))\n", " out = self.fc2(out)\n", " return out" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "22002" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = Net()\n", "\n", "sum([p.numel() for p in model.parameters()])" ] }, { 
"cell_type": "code", "execution_count": 187, "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "class Net(nn.Module):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)\n", " self.conv2 = nn.Conv2d(32, 16, kernel_size=3, padding=1)\n", " self.conv3 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n", " self.fc1 = nn.Linear(128, 32)\n", " self.fc2 = nn.Linear(32, 2)\n", " \n", " def forward(self, x):\n", " out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)\n", " out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)\n", " out = F.max_pool2d(torch.tanh(self.conv3(out)), 2)\n", " out = out.view(-1, 8 * 4 * 4)\n", " out = torch.tanh(self.fc1(out))\n", " out = self.fc2(out)\n", " return out" ] }, { "cell_type": "code", "execution_count": 188, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.0201, -0.1658]], grad_fn=)" ] }, "execution_count": 188, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = Net()\n", "model(img.unsqueeze(0))" ] }, { "cell_type": "code", "execution_count": 275, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 0.446465\n", "Epoch: 1, Loss: 0.654896\n", "Epoch: 2, Loss: 0.339780\n", "Epoch: 3, Loss: 0.460466\n", "Epoch: 4, Loss: 0.454873\n", "Epoch: 5, Loss: 0.168036\n", "Epoch: 6, Loss: 0.297592\n", "Epoch: 7, Loss: 0.384782\n", "Epoch: 8, Loss: 0.645598\n", "Epoch: 9, Loss: 0.159519\n", "Epoch: 10, Loss: 0.217299\n", "Epoch: 11, Loss: 0.252308\n", "Epoch: 12, Loss: 0.549105\n", "Epoch: 13, Loss: 0.196273\n", "Epoch: 14, Loss: 0.506769\n", "Epoch: 15, Loss: 0.198482\n", "Epoch: 16, Loss: 0.173888\n", "Epoch: 17, Loss: 0.315105\n", "Epoch: 18, Loss: 0.138680\n", "Epoch: 19, Loss: 0.302249\n", "Epoch: 20, Loss: 0.347972\n", "Epoch: 21, Loss: 0.167680\n", "Epoch: 22, Loss: 0.134659\n", "Epoch: 23, Loss: 0.583311\n", "Epoch: 24, Loss: 0.207660\n", "Epoch: 25, 
Loss: 0.121527\n", "Epoch: 26, Loss: 0.284999\n", "Epoch: 27, Loss: 0.320392\n", "Epoch: 28, Loss: 0.417843\n", "Epoch: 29, Loss: 0.404884\n", "Epoch: 30, Loss: 0.236999\n", "Epoch: 31, Loss: 0.426667\n", "Epoch: 32, Loss: 0.362905\n", "Epoch: 33, Loss: 0.551533\n", "Epoch: 34, Loss: 0.224683\n", "Epoch: 35, Loss: 0.074177\n", "Epoch: 36, Loss: 0.336616\n", "Epoch: 37, Loss: 0.099106\n", "Epoch: 38, Loss: 0.313262\n", "Epoch: 39, Loss: 0.164988\n", "Epoch: 40, Loss: 0.143705\n", "Epoch: 41, Loss: 0.203131\n", "Epoch: 42, Loss: 0.430345\n", "Epoch: 43, Loss: 0.203300\n", "Epoch: 44, Loss: 0.152889\n", "Epoch: 45, Loss: 0.074340\n", "Epoch: 46, Loss: 0.227526\n", "Epoch: 47, Loss: 0.254593\n", "Epoch: 48, Loss: 0.032488\n", "Epoch: 49, Loss: 0.264401\n", "Epoch: 50, Loss: 0.117654\n", "Epoch: 51, Loss: 0.161823\n", "Epoch: 52, Loss: 0.267448\n", "Epoch: 53, Loss: 0.303165\n", "Epoch: 54, Loss: 0.149452\n", "Epoch: 55, Loss: 0.283597\n", "Epoch: 56, Loss: 0.538679\n", "Epoch: 57, Loss: 0.091895\n", "Epoch: 58, Loss: 0.168044\n", "Epoch: 59, Loss: 0.079505\n", "Epoch: 60, Loss: 0.166282\n", "Epoch: 61, Loss: 0.080623\n", "Epoch: 62, Loss: 0.628037\n", "Epoch: 63, Loss: 0.105842\n", "Epoch: 64, Loss: 0.068237\n", "Epoch: 65, Loss: 0.219858\n", "Epoch: 66, Loss: 0.087648\n", "Epoch: 67, Loss: 0.063252\n", "Epoch: 68, Loss: 0.174067\n", "Epoch: 69, Loss: 0.202236\n", "Epoch: 70, Loss: 0.125508\n", "Epoch: 71, Loss: 0.335009\n", "Epoch: 72, Loss: 0.052561\n", "Epoch: 73, Loss: 0.163582\n", "Epoch: 74, Loss: 0.218372\n", "Epoch: 75, Loss: 0.055700\n", "Epoch: 76, Loss: 0.053517\n", "Epoch: 77, Loss: 0.074189\n", "Epoch: 78, Loss: 0.234757\n", "Epoch: 79, Loss: 0.061534\n", "Epoch: 80, Loss: 0.144434\n", "Epoch: 81, Loss: 0.376511\n", "Epoch: 82, Loss: 0.149837\n", "Epoch: 83, Loss: 0.035641\n", "Epoch: 84, Loss: 0.027863\n", "Epoch: 85, Loss: 0.236207\n", "Epoch: 86, Loss: 0.125776\n", "Epoch: 87, Loss: 0.215485\n", "Epoch: 88, Loss: 0.093286\n", "Epoch: 89, Loss: 
0.284476\n", "Epoch: 90, Loss: 0.091811\n", "Epoch: 91, Loss: 0.050124\n", "Epoch: 92, Loss: 0.306853\n", "Epoch: 93, Loss: 0.158289\n", "Epoch: 94, Loss: 0.165751\n", "Epoch: 95, Loss: 0.228794\n", "Epoch: 96, Loss: 0.149434\n", "Epoch: 97, Loss: 0.041878\n", "Epoch: 98, Loss: 0.032688\n", "Epoch: 99, Loss: 0.044938\n" ] } ], "source": [
 "import torch\n",
 "import torch.nn as nn\n",
 "import torch.nn.functional as F\n",
 "import torch.optim as optim  # needed for optim.SGD below; keeps this cell self-contained\n",
 "\n",
 "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n",
 "\n",
 "class Net(nn.Module):\n",
 "    # Two conv -> relu -> 2x2 max-pool stages, then a tanh hidden layer and a 2-way output.\n",
 "    def __init__(self):\n",
 "        super(Net, self).__init__()\n",
 "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
 "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
 "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
 "        self.fc2 = nn.Linear(32, 2)\n",
 "\n",
 "    def forward(self, x):\n",
 "        out = F.max_pool2d(torch.relu(self.conv1(x)), 2)\n",
 "        out = F.max_pool2d(torch.relu(self.conv2(out)), 2)\n",
 "        out = out.view(-1, 8 * 8 * 8)\n",
 "        out = torch.tanh(self.fc1(out))\n",
 "        out = self.fc2(out)\n",
 "        return out\n",
 "\n",
 "model = Net()\n",
 "\n",
 "learning_rate = 1e-2\n",
 "\n",
 "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
 "\n",
 "loss_fn = nn.CrossEntropyLoss()\n",
 "\n",
 "nepochs = 100\n",
 "\n",
 "for epoch in range(nepochs):\n",
 "    for imgs, labels in train_loader:\n",
 "        outputs = model(imgs)\n",
 "        loss = loss_fn(outputs, labels)\n",
 "\n",
 "        optimizer.zero_grad()\n",
 "        loss.backward()\n",
 "        optimizer.step()\n",
 "\n",
 "    # float(loss) is the loss of the epoch's last mini-batch only, not an epoch average\n",
 "    print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 277, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.963100\n" ] } ], "source": [ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "with torch.no_grad():\n", " for imgs, labels in train_loader:\n", " outputs = model(imgs)\n", " _, predicted = torch.max(outputs, 
dim=1)\n", " total += labels.shape[0]\n", " correct += int((predicted == labels).sum())\n", " \n", "print(\"Accuracy: %f\" % (correct / total))" ] },
{ "cell_type": "code", "execution_count": 276, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.889500\n" ] } ], "source": [
 "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
 "\n",
 "correct = 0\n",
 "total = 0\n",
 "\n",
 "# count correct arg-max predictions over the validation set, without tracking gradients\n",
 "with torch.no_grad():\n",
 "    for imgs, labels in val_loader:\n",
 "        outputs = model(imgs)\n",
 "        _, predicted = torch.max(outputs, dim=1)\n",
 "        total += labels.shape[0]\n",
 "        correct += int((predicted == labels).sum())\n",
 "\n",
 "print(\"Accuracy: %f\" % (correct / total))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }