{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%matplotlib inline\n", "from matplotlib import pyplot as plt\n", "import numpy as np\n", "import torch\n", "\n", "torch.set_printoptions(edgeitems=2)\n", "torch.manual_seed(123)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class_names = ['airplane','automobile','bird','cat','deer',\n", " 'dog','frog','horse','ship','truck']" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from torchvision import datasets, transforms\n", "data_path = '../data-unversioned/p1ch7/'\n", "\n", "# Built once and shared by the training and validation splits so both are\n", "# normalized with exactly the same statistics.\n", "cifar10_transform = transforms.Compose([\n", "    transforms.ToTensor(),\n", "    transforms.Normalize((0.4915, 0.4823, 0.4468),\n", "                         (0.2470, 0.2435, 0.2616))\n", "])\n", "\n", "cifar10 = datasets.CIFAR10(\n", "    data_path, train=True, download=False,\n", "    transform=cifar10_transform)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "cifar10_val = datasets.CIFAR10(\n", "    data_path, train=False, download=False,\n", "    transform=cifar10_transform)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "label_map = {0: 0, 2: 1}\n", "class_names = ['airplane', 'bird']\n", "# Membership is tested against label_map itself, so the set of kept classes\n", "# is declared in exactly one place.\n", "cifar2 = [(img, label_map[label])\n", "          for img, label in cifar10\n", "          if label in label_map]\n", "cifar2_val = [(img, label_map[label])\n", "              for img, label in cifar10_val\n", "              if label in label_map]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "\n", "n_out = 2\n", "\n", "model = nn.Sequential(\n", " nn.Linear(\n", " 3072, # <1>\n", " 512, # <2>\n", " ),\n", " nn.Tanh(),\n", " nn.Linear(\n", " 512, # <2>\n", "
n_out, # <3>\n", " )\n", " )" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def softmax(x):\n", "    \"\"\"Softmax over all elements of ``x``: exponentials normalized to sum to 1.\"\"\"\n", "    # Subtracting the max is mathematically a no-op (it cancels in the ratio)\n", "    # but keeps exp() from overflowing to inf, which would yield nan outputs.\n", "    exps = torch.exp(x - x.max())\n", "    return exps / exps.sum()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.0900, 0.2447, 0.6652])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = torch.tensor([1.0, 2.0, 3.0])\n", "\n", "softmax(x)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1.)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "softmax(x).sum()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.0900, 0.2447, 0.6652],\n", " [0.0900, 0.2447, 0.6652]])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "softmax = nn.Softmax(dim=1)\n", "\n", "x = torch.tensor([[1.0, 2.0, 3.0],\n", " [1.0, 2.0, 3.0]])\n", "\n", "softmax(x)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " nn.Linear(3072, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 2),\n", " nn.Softmax(dim=1))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "img, _ = cifar2[0]\n", "\n", "plt.imshow(img.permute(1, 2, 0))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Flatten the image into a single vector, then add a leading batch dimension.\n", "img_batch = img.view(-1).unsqueeze(0)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.4784, 0.5216]], grad_fn=)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out = model(img_batch)\n", "out" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# torch.max along a dim returns (values, indices); only the indices are needed.\n", "_, index = torch.max(out, dim=1)\n", "\n", "index" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1., 0.],\n", " [1., 0.],\n", " [0., 1.],\n", " [0., 1.]])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out = torch.tensor([\n", "    [0.6, 0.4],\n", "    [0.9, 0.1],\n", "    [0.3, 0.7],\n", "    [0.2, 0.8],\n", "])\n", "class_index = torch.tensor([0, 0, 1, 1]).unsqueeze(1)\n", "\n", "# One-hot encode: write 1.0 into each row at the column given by class_index.\n", "truth = torch.zeros((4, 2))\n", "truth.scatter_(dim=1, index=class_index, value=1.0)\n", "truth" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.1500)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def mse(out):\n", "    \"\"\"Mean over rows of the squared distance between ``out`` and the global ``truth``.\"\"\"\n", "    squared_error = (out - truth).pow(2)\n", "    return squared_error.sum(dim=1).mean()\n", "\n", "mse(out)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.6000],\n", " [0.9000],\n", " [0.7000],\n", " [0.8000]])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out.gather(dim=1, 
index=class_index)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.3024])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def likelihood(out):\n", "    \"\"\"Product of the probabilities ``out`` assigns to the classes in ``class_index``.\"\"\"\n", "    # A single vectorised reduction over the batch dimension; unlike a Python\n", "    # accumulation loop it always returns a tensor, even for an empty batch.\n", "    return out.gather(dim=1, index=class_index).prod(dim=0)\n", "\n", "likelihood(out)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1.1960])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def neg_log_likelihood(out):\n", "    \"\"\"Negative log of ``likelihood(out)``, computed as a sum of per-sample logs.\"\"\"\n", "    # log(prod(p)) == sum(log(p)); summing logs avoids the underflow to 0 (and\n", "    # the resulting infinite loss) that the raw product hits on larger batches.\n", "    true_class_probs = out.gather(dim=1, index=class_index)\n", "    return -true_class_probs.log().sum(dim=0)\n", "\n", "neg_log_likelihood(out)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.0750, 0.1500, 0.2500, 0.4750])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out0 = out.clone().detach()\n", "out0[0] = torch.tensor([0.9, 0.1]) # more right\n", "\n", "out2 = out.clone().detach()\n", "out2[0] = torch.tensor([0.4, 0.6]) # slightly wrong\n", "\n", "out3 = out.clone().detach()\n", "out3[0] = torch.tensor([0.1, 0.9]) # very wrong\n", "\n", "mse_comparison = torch.tensor([mse(o) for o in [out0, out, out2, out3]])\n", "mse_comparison" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-50.0000, 0.0000, 66.6667, 216.6667])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "((mse_comparison / mse_comparison[1]) - 1) * 100" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0.7905, 1.1960, 1.6015, 2.9878])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nll_comparison = torch.tensor([neg_log_likelihood(o) \n", " for o in [out0, out, out2, out3]])\n", "nll_comparison" ] 
}, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-33.9016, 0.0000, 33.9016, 149.8121])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "((nll_comparison / nll_comparison[1]) - 1) * 100" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 1.]])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "softmax = nn.Softmax(dim=1)\n", "\n", "log_softmax = nn.LogSoftmax(dim=1)\n", "\n", "# With a logit gap this large the first probability underflows to exactly 0.\n", "x = torch.tensor([[0.0, 104.0]])\n", "\n", "softmax(x)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-inf, 0.]])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.log(softmax(x))" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-104., 0.]])" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "log_softmax(x)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 1.]])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.exp(log_softmax(x))" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " nn.Linear(3072, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 2),\n", " nn.LogSoftmax(dim=1))" ] }, { 
"cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "loss = nn.NLLLoss()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.5077, grad_fn=)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img, label = cifar2[0]\n", "\n", "out = model(img.view(-1).unsqueeze(0))\n", "\n", "loss(out, torch.tensor([label]))" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 5.347057\n", "Epoch: 1, Loss: 7.705317\n", "Epoch: 2, Loss: 6.510838\n", "Epoch: 3, Loss: 9.557189\n", "Epoch: 4, Loss: 4.151933\n", "Epoch: 5, Loss: 5.636873\n", "Epoch: 6, Loss: 6.531207\n", "Epoch: 7, Loss: 20.450516\n", "Epoch: 8, Loss: 5.072948\n", "Epoch: 9, Loss: 4.941860\n", "Epoch: 10, Loss: 6.445535\n", "Epoch: 11, Loss: 4.580799\n", "Epoch: 12, Loss: 6.660308\n", "Epoch: 13, Loss: 9.436373\n", "Epoch: 14, Loss: 16.786476\n", "Epoch: 15, Loss: 8.349138\n", "Epoch: 16, Loss: 8.176860\n", "Epoch: 17, Loss: 5.862664\n", "Epoch: 18, Loss: 8.218906\n", "Epoch: 19, Loss: 13.296558\n", "Epoch: 20, Loss: 7.313433\n", "Epoch: 21, Loss: 4.585245\n", "Epoch: 22, Loss: 11.706884\n", "Epoch: 23, Loss: 18.208710\n", "Epoch: 24, Loss: 0.343157\n", "Epoch: 25, Loss: 9.255491\n", "Epoch: 26, Loss: 10.466807\n", "Epoch: 27, Loss: 12.226366\n", "Epoch: 28, Loss: 12.728527\n", "Epoch: 29, Loss: 9.777843\n", "Epoch: 30, Loss: 6.128856\n", "Epoch: 31, Loss: 13.284330\n", "Epoch: 32, Loss: 10.321814\n", "Epoch: 33, Loss: 2.928349\n", "Epoch: 34, Loss: 8.623670\n", "Epoch: 35, Loss: 12.719531\n", "Epoch: 36, Loss: 4.030444\n", "Epoch: 37, Loss: 4.621825\n", "Epoch: 38, Loss: 13.210777\n", "Epoch: 39, Loss: 14.217413\n", "Epoch: 40, Loss: 3.880259\n", "Epoch: 41, Loss: 13.189833\n", "Epoch: 42, Loss: 17.787762\n", "Epoch: 43, Loss: 3.953930\n", "Epoch: 44, Loss: 0.640078\n", "Epoch: 
45, Loss: 9.262226\n", "Epoch: 46, Loss: 7.383645\n", "Epoch: 47, Loss: 5.352252\n", "Epoch: 48, Loss: 11.515299\n", "Epoch: 49, Loss: 12.266010\n", "Epoch: 50, Loss: 12.210896\n", "Epoch: 51, Loss: 3.987965\n", "Epoch: 52, Loss: 12.570765\n", "Epoch: 53, Loss: 13.025002\n", "Epoch: 54, Loss: 13.747946\n", "Epoch: 55, Loss: 6.783926\n", "Epoch: 56, Loss: 11.822943\n", "Epoch: 57, Loss: 8.200066\n", "Epoch: 58, Loss: 9.206728\n", "Epoch: 59, Loss: 7.715425\n", "Epoch: 60, Loss: 5.571069\n", "Epoch: 61, Loss: 13.017315\n", "Epoch: 62, Loss: 10.307802\n", "Epoch: 63, Loss: 2.660404\n", "Epoch: 64, Loss: 11.096642\n", "Epoch: 65, Loss: 5.284830\n", "Epoch: 66, Loss: 8.374750\n", "Epoch: 67, Loss: 1.418676\n", "Epoch: 68, Loss: 9.891462\n", "Epoch: 69, Loss: 9.079073\n", "Epoch: 70, Loss: 6.453581\n", "Epoch: 71, Loss: 8.293860\n", "Epoch: 72, Loss: 4.585221\n", "Epoch: 73, Loss: 14.174129\n", "Epoch: 74, Loss: 6.072280\n", "Epoch: 75, Loss: 5.925417\n", "Epoch: 76, Loss: 0.260600\n", "Epoch: 77, Loss: 3.055498\n", "Epoch: 78, Loss: 0.347163\n", "Epoch: 79, Loss: 3.497080\n", "Epoch: 80, Loss: 6.615281\n", "Epoch: 81, Loss: 8.944511\n", "Epoch: 82, Loss: 10.230938\n", "Epoch: 83, Loss: 6.776264\n", "Epoch: 84, Loss: 10.169885\n", "Epoch: 85, Loss: 7.014330\n", "Epoch: 86, Loss: 3.467798\n", "Epoch: 87, Loss: 3.772486\n", "Epoch: 88, Loss: 13.495383\n", "Epoch: 89, Loss: 11.781836\n", "Epoch: 90, Loss: 6.853724\n", "Epoch: 91, Loss: 3.313806\n", "Epoch: 92, Loss: 7.867707\n", "Epoch: 93, Loss: 16.117371\n", "Epoch: 94, Loss: 15.077475\n", "Epoch: 95, Loss: 17.807060\n", "Epoch: 96, Loss: 16.376089\n", "Epoch: 97, Loss: 9.348265\n", "Epoch: 98, Loss: 18.044790\n", "Epoch: 99, Loss: 15.565783\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "model = nn.Sequential(\n", " nn.Linear(3072, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 2),\n", " nn.LogSoftmax(dim=1))\n", "\n", "learning_rate = 1e-2\n", "\n", "optimizer = 
optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.NLLLoss()\n", "\n", "n_epochs = 100\n", "\n", "for epoch in range(n_epochs):\n", " for img, label in cifar2:\n", " out = model(img.view(-1).unsqueeze(0))\n", " loss = loss_fn(out, torch.tensor([label]))\n", " \n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n", " shuffle=True)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 0.604063\n", "Epoch: 1, Loss: 0.597974\n", "Epoch: 2, Loss: 0.271415\n", "Epoch: 3, Loss: 0.451056\n", "Epoch: 4, Loss: 0.629758\n", "Epoch: 5, Loss: 0.458762\n", "Epoch: 6, Loss: 0.277813\n", "Epoch: 7, Loss: 0.406921\n", "Epoch: 8, Loss: 0.951961\n", "Epoch: 9, Loss: 0.433738\n", "Epoch: 10, Loss: 0.351960\n", "Epoch: 11, Loss: 0.355687\n", "Epoch: 12, Loss: 0.518611\n", "Epoch: 13, Loss: 0.262623\n", "Epoch: 14, Loss: 0.221969\n", "Epoch: 15, Loss: 0.774132\n", "Epoch: 16, Loss: 0.324406\n", "Epoch: 17, Loss: 0.447701\n", "Epoch: 18, Loss: 0.299780\n", "Epoch: 19, Loss: 0.267090\n", "Epoch: 20, Loss: 0.279828\n", "Epoch: 21, Loss: 0.197123\n", "Epoch: 22, Loss: 0.196783\n", "Epoch: 23, Loss: 0.328715\n", "Epoch: 24, Loss: 0.334952\n", "Epoch: 25, Loss: 0.500689\n", "Epoch: 26, Loss: 0.186956\n", "Epoch: 27, Loss: 0.138649\n", "Epoch: 28, Loss: 0.239988\n", "Epoch: 29, Loss: 0.495020\n", "Epoch: 30, Loss: 0.251347\n", "Epoch: 31, Loss: 0.088298\n", "Epoch: 32, Loss: 0.175127\n", "Epoch: 33, Loss: 0.208338\n", "Epoch: 34, Loss: 0.145656\n", "Epoch: 35, Loss: 0.129570\n", "Epoch: 36, Loss: 0.200110\n", "Epoch: 37, Loss: 0.133076\n", "Epoch: 38, Loss: 0.230561\n", "Epoch: 39, Loss: 0.241688\n", "Epoch: 40, Loss: 0.106870\n", 
"Epoch: 41, Loss: 0.281168\n", "Epoch: 42, Loss: 0.175034\n", "Epoch: 43, Loss: 0.073779\n", "Epoch: 44, Loss: 0.171294\n", "Epoch: 45, Loss: 0.112456\n", "Epoch: 46, Loss: 0.132553\n", "Epoch: 47, Loss: 0.048826\n", "Epoch: 48, Loss: 0.076014\n", "Epoch: 49, Loss: 0.122317\n", "Epoch: 50, Loss: 0.103442\n", "Epoch: 51, Loss: 0.201585\n", "Epoch: 52, Loss: 0.145637\n", "Epoch: 53, Loss: 0.055844\n", "Epoch: 54, Loss: 0.046278\n", "Epoch: 55, Loss: 0.081562\n", "Epoch: 56, Loss: 0.058857\n", "Epoch: 57, Loss: 0.197200\n", "Epoch: 58, Loss: 0.044184\n", "Epoch: 59, Loss: 0.043374\n", "Epoch: 60, Loss: 0.032936\n", "Epoch: 61, Loss: 0.072488\n", "Epoch: 62, Loss: 0.060811\n", "Epoch: 63, Loss: 0.029262\n", "Epoch: 64, Loss: 0.036435\n", "Epoch: 65, Loss: 0.058120\n", "Epoch: 66, Loss: 0.063329\n", "Epoch: 67, Loss: 0.020670\n", "Epoch: 68, Loss: 0.077189\n", "Epoch: 69, Loss: 0.060933\n", "Epoch: 70, Loss: 0.070848\n", "Epoch: 71, Loss: 0.036434\n", "Epoch: 72, Loss: 0.084855\n", "Epoch: 73, Loss: 0.044776\n", "Epoch: 74, Loss: 0.037828\n", "Epoch: 75, Loss: 0.024554\n", "Epoch: 76, Loss: 0.018965\n", "Epoch: 77, Loss: 0.033381\n", "Epoch: 78, Loss: 0.016183\n", "Epoch: 79, Loss: 0.020083\n", "Epoch: 80, Loss: 0.041192\n", "Epoch: 81, Loss: 0.015122\n", "Epoch: 82, Loss: 0.014245\n", "Epoch: 83, Loss: 0.018538\n", "Epoch: 84, Loss: 0.044791\n", "Epoch: 85, Loss: 0.034532\n", "Epoch: 86, Loss: 0.010175\n", "Epoch: 87, Loss: 0.021837\n", "Epoch: 88, Loss: 0.005545\n", "Epoch: 89, Loss: 0.012682\n", "Epoch: 90, Loss: 0.026414\n", "Epoch: 91, Loss: 0.021372\n", "Epoch: 92, Loss: 0.025901\n", "Epoch: 93, Loss: 0.025262\n", "Epoch: 94, Loss: 0.047044\n", "Epoch: 95, Loss: 0.016064\n", "Epoch: 96, Loss: 0.059213\n", "Epoch: 97, Loss: 0.017386\n", "Epoch: 98, Loss: 0.016215\n", "Epoch: 99, Loss: 0.016987\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "train_loader = torch.utils.data.DataLoader(cifar2, 
batch_size=64,\n", " shuffle=True)\n", "\n", "model = nn.Sequential(\n", " nn.Linear(3072, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2),\n", " nn.LogSoftmax(dim=1))\n", "\n", "learning_rate = 1e-2\n", "\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.NLLLoss()\n", "\n", "n_epochs = 100\n", "\n", "for epoch in range(n_epochs):\n", " for imgs, labels in train_loader:\n", " outputs = model(imgs.view(imgs.shape[0], -1))\n", " loss = loss_fn(outputs, labels)\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 0.732168\n", "Epoch: 1, Loss: 0.348352\n", "Epoch: 2, Loss: 0.318960\n", "Epoch: 3, Loss: 0.313264\n", "Epoch: 4, Loss: 0.378358\n", "Epoch: 5, Loss: 0.276529\n", "Epoch: 6, Loss: 0.443889\n", "Epoch: 7, Loss: 0.436946\n", "Epoch: 8, Loss: 0.324288\n", "Epoch: 9, Loss: 0.274647\n", "Epoch: 10, Loss: 0.291681\n", "Epoch: 11, Loss: 0.242894\n", "Epoch: 12, Loss: 0.301849\n", "Epoch: 13, Loss: 0.202063\n", "Epoch: 14, Loss: 0.389276\n", "Epoch: 15, Loss: 0.167129\n", "Epoch: 16, Loss: 0.135282\n", "Epoch: 17, Loss: 0.385485\n", "Epoch: 18, Loss: 0.453852\n", "Epoch: 19, Loss: 0.641304\n", "Epoch: 20, Loss: 0.287667\n", "Epoch: 21, Loss: 0.337029\n", "Epoch: 22, Loss: 0.393282\n", "Epoch: 23, Loss: 0.409480\n", "Epoch: 24, Loss: 0.138473\n", "Epoch: 25, Loss: 0.690729\n", "Epoch: 26, Loss: 0.572156\n", "Epoch: 27, Loss: 0.078534\n", "Epoch: 28, Loss: 0.324833\n", "Epoch: 29, Loss: 0.262829\n", "Epoch: 30, Loss: 0.430449\n", "Epoch: 31, Loss: 0.071872\n", "Epoch: 32, Loss: 0.058039\n", "Epoch: 33, Loss: 0.052903\n", "Epoch: 34, Loss: 0.065879\n", "Epoch: 35, Loss: 0.107696\n", "Epoch: 36, Loss: 0.305224\n", "Epoch: 37, Loss: 0.098637\n", "Epoch: 38, Loss: 0.139823\n", "Epoch: 39, Loss: 
0.226455\n", "Epoch: 40, Loss: 0.117763\n", "Epoch: 41, Loss: 0.106498\n", "Epoch: 42, Loss: 0.086254\n", "Epoch: 43, Loss: 0.135652\n", "Epoch: 44, Loss: 0.070890\n", "Epoch: 45, Loss: 0.304346\n", "Epoch: 46, Loss: 0.016917\n", "Epoch: 47, Loss: 0.057929\n", "Epoch: 48, Loss: 0.131021\n", "Epoch: 49, Loss: 0.136299\n", "Epoch: 50, Loss: 0.048885\n", "Epoch: 51, Loss: 0.241048\n", "Epoch: 52, Loss: 0.092595\n", "Epoch: 53, Loss: 0.059137\n", "Epoch: 54, Loss: 0.047421\n", "Epoch: 55, Loss: 0.102036\n", "Epoch: 56, Loss: 0.023338\n", "Epoch: 57, Loss: 0.054306\n", "Epoch: 58, Loss: 0.073878\n", "Epoch: 59, Loss: 0.031387\n", "Epoch: 60, Loss: 0.039865\n", "Epoch: 61, Loss: 0.022344\n", "Epoch: 62, Loss: 0.052310\n", "Epoch: 63, Loss: 0.059688\n", "Epoch: 64, Loss: 0.023977\n", "Epoch: 65, Loss: 0.010632\n", "Epoch: 66, Loss: 0.039090\n", "Epoch: 67, Loss: 0.080844\n", "Epoch: 68, Loss: 0.029650\n", "Epoch: 69, Loss: 0.027038\n", "Epoch: 70, Loss: 0.028515\n", "Epoch: 71, Loss: 0.021998\n", "Epoch: 72, Loss: 0.014992\n", "Epoch: 73, Loss: 0.019659\n", "Epoch: 74, Loss: 0.025150\n", "Epoch: 75, Loss: 0.017384\n", "Epoch: 76, Loss: 0.013249\n", "Epoch: 77, Loss: 0.009451\n", "Epoch: 78, Loss: 0.034637\n", "Epoch: 79, Loss: 0.114242\n", "Epoch: 80, Loss: 0.019007\n", "Epoch: 81, Loss: 0.016319\n", "Epoch: 82, Loss: 0.027428\n", "Epoch: 83, Loss: 0.022366\n", "Epoch: 84, Loss: 0.022583\n", "Epoch: 85, Loss: 0.006275\n", "Epoch: 86, Loss: 0.011964\n", "Epoch: 87, Loss: 0.018711\n", "Epoch: 88, Loss: 0.019636\n", "Epoch: 89, Loss: 0.018975\n", "Epoch: 90, Loss: 0.023520\n", "Epoch: 91, Loss: 0.016398\n", "Epoch: 92, Loss: 0.006638\n", "Epoch: 93, Loss: 0.013305\n", "Epoch: 94, Loss: 0.017126\n", "Epoch: 95, Loss: 0.021641\n", "Epoch: 96, Loss: 0.036945\n", "Epoch: 97, Loss: 0.004735\n", "Epoch: 98, Loss: 0.016781\n", "Epoch: 99, Loss: 0.012039\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "train_loader = 
torch.utils.data.DataLoader(cifar2, batch_size=64,\n", "                                           shuffle=True)\n", "\n", "model = nn.Sequential(\n", "    nn.Linear(3072, 512),\n", "    nn.Tanh(),\n", "    nn.Linear(512, 2),\n", "    nn.LogSoftmax(dim=1))\n", "\n", "learning_rate = 1e-2\n", "\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.NLLLoss()\n", "\n", "n_epochs = 100\n", "\n", "for epoch in range(n_epochs):\n", "    for imgs, labels in train_loader:\n", "        outputs = model(imgs.view(imgs.shape[0], -1))\n", "        loss = loss_fn(outputs, labels)\n", "\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "    print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.997700\n" ] } ], "source": [ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n", "                                           shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "# Scoring only: no gradients are needed.\n", "with torch.no_grad():\n", "    for imgs, labels in train_loader:\n", "        outputs = model(imgs.view(imgs.shape[0], -1))\n", "        # Predicted class = column with the highest score in each row.\n", "        predicted = outputs.argmax(dim=1)\n", "        total += labels.shape[0]\n", "        correct += int((predicted == labels).sum())\n", "\n", "print(\"Accuracy: %f\" % (correct / total))" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.821000\n" ] } ], "source": [ "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n", "                                         shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "# Scoring only: no gradients are needed.\n", "with torch.no_grad():\n", "    for imgs, labels in val_loader:\n", "        outputs = model(imgs.view(imgs.shape[0], -1))\n", "        # Predicted class = column with the highest score in each row.\n", "        predicted = outputs.argmax(dim=1)\n", "        total += labels.shape[0]\n", "        correct += int((predicted == labels).sum())\n", "\n", "print(\"Accuracy: %f\" % (correct / total))" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], 
"source": [ "model = nn.Sequential(\n", " nn.Linear(3072, 1024),\n", " nn.Tanh(),\n", " nn.Linear(1024, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2),\n", " nn.LogSoftmax(dim=1))" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " nn.Linear(3072, 1024),\n", " nn.Tanh(),\n", " nn.Linear(1024, 512),\n", " nn.Tanh(),\n", " nn.Linear(512, 128),\n", " nn.Tanh(),\n", " nn.Linear(128, 2))\n", "\n", "loss_fn = nn.CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, Loss: 0.641261\n", "Epoch: 1, Loss: 0.525149\n", "Epoch: 2, Loss: 0.466143\n", "Epoch: 3, Loss: 0.451913\n", "Epoch: 4, Loss: 0.343860\n", "Epoch: 5, Loss: 0.309738\n", "Epoch: 6, Loss: 0.485261\n", "Epoch: 7, Loss: 0.283789\n", "Epoch: 8, Loss: 0.301561\n", "Epoch: 9, Loss: 0.408200\n", "Epoch: 10, Loss: 0.346715\n", "Epoch: 11, Loss: 0.358134\n", "Epoch: 12, Loss: 0.388485\n", "Epoch: 13, Loss: 0.378096\n", "Epoch: 14, Loss: 0.518019\n", "Epoch: 15, Loss: 0.359279\n", "Epoch: 16, Loss: 0.420371\n", "Epoch: 17, Loss: 0.366249\n", "Epoch: 18, Loss: 0.282639\n", "Epoch: 19, Loss: 0.468854\n", "Epoch: 20, Loss: 0.467920\n", "Epoch: 21, Loss: 0.237441\n", "Epoch: 22, Loss: 0.243472\n", "Epoch: 23, Loss: 0.566929\n", "Epoch: 24, Loss: 0.316143\n", "Epoch: 25, Loss: 0.336322\n", "Epoch: 26, Loss: 0.473064\n", "Epoch: 27, Loss: 0.407040\n", "Epoch: 28, Loss: 0.252989\n", "Epoch: 29, Loss: 0.195740\n", "Epoch: 30, Loss: 0.663084\n", "Epoch: 31, Loss: 0.659899\n", "Epoch: 32, Loss: 0.285113\n", "Epoch: 33, Loss: 0.212042\n", "Epoch: 34, Loss: 0.324017\n", "Epoch: 35, Loss: 0.097063\n", "Epoch: 36, Loss: 0.181754\n", "Epoch: 37, Loss: 0.091362\n", "Epoch: 38, Loss: 0.069348\n", "Epoch: 39, Loss: 0.085656\n", "Epoch: 40, Loss: 0.163399\n", "Epoch: 41, Loss: 0.064912\n", "Epoch: 42, Loss: 
0.046740\n", "Epoch: 43, Loss: 0.029891\n", "Epoch: 44, Loss: 0.018157\n", "Epoch: 45, Loss: 0.103532\n", "Epoch: 46, Loss: 0.161911\n", "Epoch: 47, Loss: 0.238185\n", "Epoch: 48, Loss: 0.081116\n", "Epoch: 49, Loss: 0.040988\n", "Epoch: 50, Loss: 0.008668\n", "Epoch: 51, Loss: 0.012557\n", "Epoch: 52, Loss: 0.015967\n", "Epoch: 53, Loss: 0.020964\n", "Epoch: 54, Loss: 0.023478\n", "Epoch: 55, Loss: 0.012850\n", "Epoch: 56, Loss: 0.054703\n", "Epoch: 57, Loss: 0.014922\n", "Epoch: 58, Loss: 0.045488\n", "Epoch: 59, Loss: 0.122221\n", "Epoch: 60, Loss: 0.028012\n", "Epoch: 61, Loss: 0.029533\n", "Epoch: 62, Loss: 0.004758\n", "Epoch: 63, Loss: 0.080409\n", "Epoch: 64, Loss: 0.005409\n", "Epoch: 65, Loss: 0.020399\n", "Epoch: 66, Loss: 0.008184\n", "Epoch: 67, Loss: 0.013888\n", "Epoch: 68, Loss: 0.002199\n", "Epoch: 69, Loss: 0.001918\n", "Epoch: 70, Loss: 0.018765\n", "Epoch: 71, Loss: 0.004223\n", "Epoch: 72, Loss: 0.001795\n", "Epoch: 73, Loss: 0.102238\n", "Epoch: 74, Loss: 0.002482\n", "Epoch: 75, Loss: 0.005807\n", "Epoch: 76, Loss: 0.001742\n", "Epoch: 77, Loss: 0.012760\n", "Epoch: 78, Loss: 0.017469\n", "Epoch: 79, Loss: 0.002849\n", "Epoch: 80, Loss: 0.001452\n", "Epoch: 81, Loss: 0.002740\n", "Epoch: 82, Loss: 0.003317\n", "Epoch: 83, Loss: 0.002066\n", "Epoch: 84, Loss: 0.001952\n", "Epoch: 85, Loss: 0.010757\n", "Epoch: 86, Loss: 0.004866\n", "Epoch: 87, Loss: 0.003957\n", "Epoch: 88, Loss: 0.001295\n", "Epoch: 89, Loss: 0.004410\n", "Epoch: 90, Loss: 0.002952\n", "Epoch: 91, Loss: 0.000676\n", "Epoch: 92, Loss: 0.001835\n", "Epoch: 93, Loss: 0.000739\n", "Epoch: 94, Loss: 0.001102\n", "Epoch: 95, Loss: 0.000792\n", "Epoch: 96, Loss: 0.000515\n", "Epoch: 97, Loss: 0.001548\n", "Epoch: 98, Loss: 0.026913\n", "Epoch: 99, Loss: 0.000140\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "\n", "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n", " shuffle=True)\n", "\n", "model = 
nn.Sequential(\n", "    nn.Linear(3072, 1024),\n", "    nn.Tanh(),\n", "    nn.Linear(1024, 512),\n", "    nn.Tanh(),\n", "    nn.Linear(512, 128),\n", "    nn.Tanh(),\n", "    nn.Linear(128, 2))\n", "\n", "learning_rate = 1e-2\n", "\n", "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", "\n", "loss_fn = nn.CrossEntropyLoss()\n", "\n", "n_epochs = 100\n", "\n", "for epoch in range(n_epochs):\n", "    for imgs, labels in train_loader:\n", "        outputs = model(imgs.view(imgs.shape[0], -1))\n", "        loss = loss_fn(outputs, labels)\n", "\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "    print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.999700\n" ] } ], "source": [ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n", "                                           shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "# Scoring only: no gradients are needed.\n", "with torch.no_grad():\n", "    for imgs, labels in train_loader:\n", "        outputs = model(imgs.view(imgs.shape[0], -1))\n", "        # Predicted class = column with the highest score in each row.\n", "        predicted = outputs.argmax(dim=1)\n", "        total += labels.shape[0]\n", "        correct += int((predicted == labels).sum())\n", "\n", "print(\"Accuracy: %f\" % (correct / total))" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.801000\n" ] } ], "source": [ "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n", "                                         shuffle=False)\n", "\n", "correct = 0\n", "total = 0\n", "\n", "# Scoring only: no gradients are needed.\n", "with torch.no_grad():\n", "    for imgs, labels in val_loader:\n", "        outputs = model(imgs.view(imgs.shape[0], -1))\n", "        # Predicted class = column with the highest score in each row.\n", "        predicted = outputs.argmax(dim=1)\n", "        total += labels.shape[0]\n", "        correct += int((predicted == labels).sum())\n", "\n", "print(\"Accuracy: %f\" % (correct / total))" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"3737474" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Total number of scalar parameters in the model.\n", "sum(p.numel() for p in model.parameters())" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3737474" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Same count, restricted to parameters that receive gradients.\n", "sum(p.numel() for p in model.parameters() if p.requires_grad)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1574402" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "first_model = nn.Sequential(\n", "    nn.Linear(3072, 512),\n", "    nn.Tanh(),\n", "    nn.Linear(512, 2),\n", "    nn.LogSoftmax(dim=1))\n", "\n", "sum(p.numel() for p in first_model.parameters())" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1573376" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(p.numel() for p in nn.Linear(3072, 512).parameters())" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3146752" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sum(p.numel() for p in nn.Linear(3072, 1024).parameters())" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1024, 3072]), torch.Size([1024]))" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "linear = nn.Linear(3072, 1024)\n", "\n", "linear.weight.shape, linear.bias.shape" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "conv = nn.Conv2d(3, 16, kernel_size=3)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([16, 3, 3, 3])" ] }, 
"execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "conv.weight.shape" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([16])" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "conv.bias.shape" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "img, _ = cifar2[0]\n", "\n", "output = conv(img.unsqueeze(0))" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([1, 3, 32, 32]), torch.Size([1, 16, 30, 30]))" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img.unsqueeze(0).shape, output.shape" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(img.permute(1, 2, 0), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 16, 30, 30])" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.shape" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 1, 32, 32])" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = conv(img.unsqueeze(0))\n", "\n", "output.shape" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " conv.bias.zero_()" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " conv.weight.fill_(1.0 / 9.0)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAF79JREFUeJztnV+MXdV1xn8L4z/YGBvb2Bjj1BD5IVFUSDRCkaiiNGkjGlUikZooeYh4QHFUgdRI6QOiUkOlPiRVkygPVSqnoJAqhND8UVAVtUEoFcoLiUMJfwIFgt1gPLGNwdgQAni8+nCv1WFy1zd3zsyca7K/nzSaO2fdfc4++5xv7r37u2vtyEyMMe1xzqQ7YIyZDBa/MY1i8RvTKBa/MY1i8RvTKBa/MY1i8RvTKBa/MY1i8RvTKOcupnFEXAN8GVgB/Etmfk49f926dblx48aRMfVNw5mZmZHbzzmn/t+1YsWKMta13bnnjh4utT91XqdPn+4U65OIWNJY1/11Hcfq3lH76xrrej2rmGpTjdWJEyd45ZVX6oGcRWfxR8QK4J+APwUOAj+NiLsz8xdVm40bN3LDDTeMjP32t78tj/XSSy+N3L5mzZqyzQUXXFDG1q1bV8Y2bNhQxjZt2rTg/b3++utlrDovgN/85jdlbKlR//BWrlxZxtQ/vVWrVo3cvnr16k7HOnXqVBk7ceJEGavG+NVXXy3bVP8wAF577bUypq6ZilX3/iuvvFK2qV6I7rjjjrLNXBbztv8q4KnMfDozXwPuBK5dxP6MMT2yGPHvAJ6Z9ffB4TZjzJuAxYh/1OeK3/lAFBF7ImJfROx7+eWXF3E4Y8xSshjxHwR2zvr7UuDQ3Cdl5t7MnMrMKfXZ2BjTL4sR/0+B3RFxWUSsAj4G3L003TLGLDedZ/sz81RE3Aj8JwOr77bMfFS1iYhyllLN9Faz+uedd96C24C2a9TsfDW7ff7553c6VjUWoMdDWUDV8dSMvhpH9W6tmtGHelZfzfar8VBukJqdr2b71RiqfnS189R9VTkBysWoxn4hFvGifP7M/AHwg8XswxgzGfwNP2MaxeI3plEsfmMaxeI3plEsfmMaZVGz/V2obBmV0aWsqIquSTPKmjt27NjI7ZdddlnZpkoGAp0kor4NqWyvytJTdqSyttTYq3aVLaoSdNQ90DWx5+jRoyO3q6QZlRSm7g+VLKRi1f2o7tMumYBz8Su/MY1i8RvTKBa/MY1i8RvTKBa/MY3S62z/zMwMJ0+eHBlTSSLVTHXXBB2VCKJm2asZW1XOqjpf0P1Xs9GKtWvXjtyuEnuqNqATpNQMfOVkdF0VWs1iq5n0qh+qTdd7p8uMPtTOlHKsKmdkIePrV35jGsXiN6ZRLH5jGsXiN6ZRLH5jGsXiN6ZRerf6XnjhhZExZUVVdk2XxBLQySpdln568cUXyzbKelGJLKofqv/VOKrxUCjbSyUYVbaXus7KKlMWbBdLTPVDWZhdk3e6jJWyFdW9My5+5TemUSx+YxrF4jemUSx+YxrF4jemUSx+YxplUVZfRBwATgIzwKnMnFLPn5mZKW0ZZWtUVp/KOFM24IYNG8qYqrlXLTWlbCN1XspGUxZhl2xGZZWpGnhda+5V1qIajyNHjpSxZ555pozt37+/jFXHU/eHyqhU/Vd2nhrHCtXHKrYQC3ApfP4/zsznlmA/xpge8dt+YxplseJP4IcR8bOI2LMUHTLG9MNi3/ZfnZmHImIrcE9EPJ6Z981+wvCfwh7Qyz0bY/plUa/8mXlo+PsI8D3gqhHP2ZuZU5k5pdZmN8b0S2fxR8S6iFh/5jHwAeCRpeqYMWZ5Wczb/m3A94bWwrnAHZn5H6pBZpa2nbJJKrrYJ6Az5lQxy8rSU8VHV65cWcaULaPsJpXFVll9aqyUValsRZVdWJ23WqLs0KFDZezRRx/t1K76qHnhhReWbRTqmikbUMWq+1FdF2UDjkvnPWT
m08AVi+6BMWYi2OozplEsfmMaxeI3plEsfmMaxeI3plF6LeCZmaXloSyg6stBKitOWWVq3bRqXUCobS9lh6nMQ2U5dikkCrU9tBQFHxdC1X9l6arsQmVvdrExlT2rYupaq36o+7G6j1XR1Sq2kOvsV35jGsXiN6ZRLH5jGsXiN6ZRLH5jGqX32f5qaSI1m1vN9qvZUDVbrmb71Wx01e75558v26gaBiqmZpxVanS1T1W38IILLuh0LDXzXTkSanzVNduyZUsZ27ZtWxm79NJLR25X56X6eOzYsU7tutQFVPfAUuBXfmMaxeI3plEsfmMaxeI3plEsfmMaxeI3plF6t/qq5Adl9VX14LokuIBOwFD2YZWAcfLkybJNl4QOgIsuuqiMKWuuqk2n2qhafIrKtoV6jJUtqpK7lFW5devWMrZr166R29W9c/z48TKm+qiSuJStW/VF7a+yKp3YY4yZF4vfmEax+I1pFIvfmEax+I1pFIvfmEaZ1+qLiNuAPweOZOY7hts2Ad8CdgEHgI9m5gvz7Utl9SkLRVlpFV0XBe1SO08tyaWsQ7XkkrKGNm7cWMbWr18/cnvX8VBWpbLtqnp8amktVcNP1VaszlnF1HXpYrHNF1OZpJXl26XuoroX5zLOK//XgGvmbLsJuDczdwP3Dv82xryJmFf8mXkfMPdf/LXA7cPHtwMfWuJ+GWOWma6f+bdl5jTA8Hf9FStjzFnJsn+9NyL2AHug++dOY8zS0/WV/3BEbAcY/j5SPTEz92bmVGZOLXdZImPM+HQV/93AdcPH1wHfX5ruGGP6Yhyr75vAe4EtEXEQ+CzwOeCuiLge+BXwkXEOlpll0UdlsVUZYiorTmX1dS38WfVdHUvZcps2bSpjVXYewNq1a8tYlaGnzlllVE5PT5exgwcPlrEqM05lzKlxVJmHKkuzstiUddglaxJgx44dZUxlEVaFP1WbqpDoQj5azyv+zPx4EXr/2Ecxxpx1+Bt+xjSKxW9Mo1j8xjSKxW9Mo1j8xjRKrwU8I6LMZFOZVJUtozLflLWlrCG13lpl9ansPHVeav25rgU8q+MpW1TZbypz7+jRo2WsytBTGZrqvBRdrrUaD4WyHJV1q9pVfVH36YEDB0ZuX+qsPmPM7yEWvzGNYvEb0ygWvzGNYvEb0ygWvzGN0rvVV2VuKbtsIfbFGbpkCUJt56mYykZTGXhq/bmua+tVNRO6Zswp+0rZZeq8K1RGmrJ11ThW46EKk7788stlTJ2zKripallU96q6F9X9PS5+5TemUSx+YxrF4jemUSx+YxrF4jemUXqd7VeoGfguSRjKPehKNRut6sGpWWo1A6xmjrvUIFQzx2q2/5JLLiljmzdvLmMqWahCzcCrc+7iEqjls6ol5eaLvfBCvWKdWo6uumbKoakS0BbiAviV35hGsfiNaRSL35hGsfiNaRSL35hGsfiNaZRxluu6Dfhz4EhmvmO47Rbgk8CZIm43Z+YP5tvXOeecUyZ8qESWyq5RySOqdp5aBklZOZWlpOrtKatP2TJdagmqfSq7VI2HskzXr19fxqo+qmumEmrUdVG2XZeahl2TzJRdrWoXVpaeGt/KJlYW8VzGeeX/GnDNiO1fyswrhz/zCt8Yc3Yxr/gz8z6gLuFqjHlTspjP/DdGxEMRcVtE1EuXGmPOSrqK/yvAW4ErgWngC9UTI2JPROyLiH3qc5sxpl86iT8zD2fmTGaeBr4KXCWeuzczpzJzaiFrhxtjlpdO4o+I7bP+/DDwyNJ0xxjTF+NYfd8E3gtsiYiDwGeB90bElUACB4BPjXOwVatW8Za3vGVkTGWPVZaHsgdVpp2y0Z577rkyVmWdqXc06qOOWgpLZbipbK8qe0z1Q1lUyjpStldlpSnL66WXXipjKqvvyJEjZawaD5Vlp1BWpeqjilX399atWxfcRt0bc5lX/Jn58RGbbx37CMaYsxJ/w8+YRrH4jWkUi9+YRrH4jWkUi9+YRum1gOd
5553HFVdcMTK2ZcuWsl1lr6iMOWXJHDt2rIw98cQTZezJJ58cuf3EiRNlG2WxqcKZKitRxapsuq5ZccpyVHZZZfWpsVcWrMo8VJZpdd6qH2o8lIXctZBrdd7Kyq7OS2V8zsWv/MY0isVvTKNY/MY0isVvTKNY/MY0isVvTKP0avWtWbOG3bt3j4xV26EuqKgymLoWYdy/f38Zq+whZTWpPirLTvVxw4YNZayyD5VV9uKLL5YxleWo2lU2oCrSqcZKjYdaF7CyvlR2oVpzT42HsuZUrBoTZdtVsYVkK/qV35hGsfiNaRSL35hGsfiNaRSL35hG6XW2/9xzz2Xz5s0jY9u2bSvbVTXmVO05hZr5VjO9v/71r0duV7P9Xeu6qaW8VBLUhReOXkJB7U/Nlh8+fLiMqYSmKllFjf3GjRvLmLrWKlYdTyVVqVl2lSClrqdqV7kman+Vm7UQTfiV35hGsfiNaRSL35hGsfiNaRSL35hGsfiNaZRxluvaCXwduBg4DezNzC9HxCbgW8AuBkt2fTQza5/s//e34E5W9eBUfTlleSg7T9V2qywxtcyUsthUkovap0ouUTZghUpyOXr0aBlT/a8STFSdu4svvriMVctTgV4urUL1o7JL50NZc2qsKttOaaWLjuYyziv/KeAzmfk24N3ADRHxduAm4N7M3A3cO/zbGPMmYV7xZ+Z0Zj4wfHwSeAzYAVwL3D582u3Ah5ark8aYpWdBn/kjYhfwTuB+YFtmTsPgHwRQLylqjDnrGFv8EXE+8B3g05lZf6/zd9vtiYh9EbFPfdY2xvTLWOKPiJUMhP+NzPzucPPhiNg+jG8HRi6Snpl7M3MqM6e6TqQYY5aeecUfg2nFW4HHMvOLs0J3A9cNH18HfH/pu2eMWS7Gyeq7GvgE8HBEPDjcdjPwOeCuiLge+BXwkfl2lJmlBVfZeVDbRmpZJWWtqI8fKqOrqp2nlpnqWitOWX2qTltlfyobSo2j6r/K0KtQtpzK6quyQUHXx6v6qOonqneo6lhqjFV2ZGUHd8kIVW3mMq/4M/PHQGUqvn/sIxljzir8DT9jGsXiN6ZRLH5jGsXiN6ZRLH5jGqXXAp5QWy8qC6+yQlSbrhl/69atK2OXXHLJyO2rVq0q2yirTNleqiiostgqq1LZkSqmrKO1a9eWserclGWnsvp27txZxpQ1V1m+avkvtdSbOme1T1XAs7Ju1XXuYrPOxa/8xjSKxW9Mo1j8xjSKxW9Mo1j8xjSKxW9Mo/Rq9amsvi721YoVK8o2yqJSxQ9VuyoTTGWjqUKRVZYg6CKd1dpuUGdHKstRjWPX7LfqvNU5q/Xz1P2hshKrduqclfWpsi2V5avOrbJFlT2o+jEufuU3plEsfmMaxeI3plEsfmMaxeI3plF6n+2vZl9VDb8qpmrZnThRVxdXMTXLXrVTs8NqBli5BMqRUDPH1ay+mjnukqAD2smoYippRjkSBw8eLGNdEnHUdVGOj1rOTY3Vpk2byljlgKglyipNLGQZL7/yG9MoFr8xjWLxG9MoFr8xjWLxG9MoFr8xjTKv1RcRO4GvAxcDp4G9mfnliLgF+CRwdPjUmzPzB2pfp0+fLmvrKfutsoDUcldPPfVUGdu/f38Ze/bZZ8vY0aNHR25XiSXKelH1ApXdpGrFVRabGitllSk7r4tVqeonKptV1TRUY1VZbMqWU0u2VfUkQY9jVf8R4PLLLx+5XdmDS8E4Pv8p4DOZ+UBErAd+FhH3DGNfysx/XL7uGWOWi3HW6psGpoePT0bEY8CO5e6YMWZ5WdBn/ojYBbwTuH+46caIeCgibouIOrnbGHPWMbb4I+J84DvApzPzBPAV4K3AlQzeGXyhaLcnIvZFxD5VhMIY0y9jiT8iVjIQ/jcy87sAmXk4M2cy8zTwVeCqUW0zc29mTmXmlKriYozpl3nFH4Np21uBxzLzi7O2b5/1tA8Djyx994wxy8U4s/1XA58AHo6IB4fbbgY+HhF
XAgkcAD41345Onz5dWnqHDx8u21WZVNPT02Wbxx9/vIypduqjSZVFqDLmVLaispSUzaOy8KradMoOUxaVquGnbMBqrKrls0CPfdc6g1VM9UNdT9VO9VFZnNU7YrW/anzVPTWXcWb7fwyMMm2lp2+MObvxN/yMaRSL35hGsfiNaRSL35hGsfiNaZReC3ieOnWK48ePl7GKKqPr0KFDC24DulCkysKrrC1lNVVZjKCLSKovRKlYtYyT6qPKVFN2kyokWlmLyvpUFpuyI1U/qvNW95vKMFX3lUKNf3Ufqz5W46uWNfudfYz9TGPM7xUWvzGNYvEb0ygWvzGNYvEb0ygWvzGN0qvVNzMzUxaSrCwqqAtkKitEFcdU7dT6f5W9otqoTDVlsSnLUVmElR25devWso3qf5djQV3cU2UrqnXwVB+VDVgVO1Xn1fWaqftKredYnXeX4q+qf3PxK78xjWLxG9MoFr8xjWLxG9MoFr8xjWLxG9MovVp9mdmp8GBlsSmrSWWjqbXuVMZflZHW1bJTGW5q/T+V/Vad9+bNm8s2auy7FvesYqrNli1bypjKjlRjVVl66rooO1JdM3XvqGy7yv5WY19pwlafMWZeLH5jGsXiN6ZRLH5jGsXiN6ZR5p3tj4g1wH3A6uHzv52Zn42Iy4A7gU3AA8AnMrPOvjhzwGIGU83OVzOzamZTzeirmm9q5riacVZJGwuZfZ2NShJRs8rV7LZarksluailwdQ4VuOv6g+uXr26jKnZeZWIUyWMqWumksIuuuiiMqbGSjkqlROgHI7K8Vnq2f5Xgfdl5hUMluO+JiLeDXwe+FJm7gZeAK4f+6jGmIkzr/hzwJl/nyuHPwm8D/j2cPvtwIeWpYfGmGVhrM/8EbFiuELvEeAe4JfA8cw88970ILBjebpojFkOxhJ/Zs5k5pXApcBVwNtGPW1U24jYExH7ImKfWt7YGNMvC5rtz8zjwH8B7wY2RsSZ2btLgZErD2Tm3sycyswpNZFijOmXecUfERdFxMbh4/OAPwEeA34E/MXwadcB31+uThpjlp5xEnu2A7dHxAoG/yzuysx/j4hfAHdGxN8D/w3cOs4BleVRUVkhyg5TiRSqD8pyrGxK9Y5GJbIolO2lbMxqTFQbdSxlAyqqBBiVlNTVnlUJXl2sVnUPqGutkn662MFd7gFlic5lXvFn5kPAO0dsf5rB539jzJsQf8PPmEax+I1pFIvfmEax+I1pFIvfmEaJLtZb54NFHAX+d/jnFuC53g5e4368EffjjbzZ+vEHmVmnHs6iV/G/4cAR+zJzaiIHdz/cD/fDb/uNaRWL35hGmaT4907w2LNxP96I+/FGfm/7MbHP/MaYyeK3/cY0ykTEHxHXRMT/RMRTEXHTJPow7MeBiHg4Ih6MiH09Hve2iDgSEY/M2rYpIu6JiCeHvy+cUD9uiYhnh2PyYER8sId+7IyIH0XEYxHxaET81XB7r2Mi+tHrmETEmoj4SUT8fNiPvxtuvywi7h+Ox7ciolvK5Rkys9cfYAWDMmCXA6uAnwNv77sfw74cALZM4LjvAd4FPDJr2z8ANw0f3wR8fkL9uAX4657HYzvwruHj9cATwNv7HhPRj17HBAjg/OHjlcD9DAro3AV8bLj9n4G/XMxxJvHKfxXwVGY+nYNS33cC106gHxMjM+8Dnp+z+VoGhVChp4KoRT96JzOnM/OB4eOTDIrF7KDnMRH96JUcsOxFcych/h3AM7P+nmTxzwR+GBE/i4g9E+rDGbZl5jQMbkJg6wT7cmNEPDT8WLDsHz9mExG7GNSPuJ8JjsmcfkDPY9JH0dxJiH9UqZFJWQ5XZ+a7gD8DboiI90yoH2cTXwHeymCNhmngC30dOCLOB74DfDozT/R13DH60fuY5CKK5o7LJMR/ENg56++y+Odyk5mHhr+PAN9jspWJDkfEdoDh7yOT6ERmHh7eeKeBr9LTmETESgaC+0Zmfne4ufcxGdWPSY3J8NgLLpo
7LpMQ/0+B3cOZy1XAx4C7++5ERKyLiPVnHgMfAB7RrZaVuxkUQoUJFkQ9I7YhH6aHMYlB4blbgccy84uzQr2OSdWPvsekt6K5fc1gzpnN/CCDmdRfAn8zoT5czsBp+DnwaJ/9AL7J4O3j6wzeCV0PbAbuBZ4c/t40oX78K/Aw8BAD8W3voR9/xOAt7EPAg8OfD/Y9JqIfvY4J8IcMiuI+xOAfzd/Oumd/AjwF/BuwejHH8Tf8jGkUf8PPmEax+I1pFIvfmEax+I1pFIvfmEax+I1pFIvfmEax+I1plP8DMQi2q65RHfEAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "output = conv(img.unsqueeze(0))\n", "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)\n", "\n", "with torch.no_grad():\n", " conv.weight[:] = torch.tensor([[-1.0, 0.0, 1.0],\n", " [-1.0, 0.0, 1.0],\n", " [-1.0, 0.0, 1.0]])\n", " conv.bias.zero_()" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "output = conv(img.unsqueeze(0))\n", "plt.imshow(output[0, 0].detach(), cmap='gray')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "pool = nn.MaxPool2d(2)" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 3, 16, 16])" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output = pool(img.unsqueeze(0))\n", "\n", "output.shape" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "ellipsis is not a Module subclass", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTanh\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mMaxPool2d\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m ...)\n\u001b[0m", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, *args)\u001b[0m\n\u001b[0;32m 51\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 52\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m 
\u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 53\u001b[1;33m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 54\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 55\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m_get_item_by_idx\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0miterator\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0midx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\Miniconda3\\envs\\book\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36madd_module\u001b[1;34m(self, name, module)\u001b[0m\n\u001b[0;32m 171\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mModule\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 172\u001b[0m raise TypeError(\"{} is not a Module subclass\".format(\n\u001b[1;32m--> 173\u001b[1;33m torch.typename(module)))\n\u001b[0m\u001b[0;32m 174\u001b[0m \u001b[1;32melif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_six\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstring_classes\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 175\u001b[0m raise TypeError(\"module name should be a string. 
Got {}\".format(\n", "\u001b[1;31mTypeError\u001b[0m: ellipsis is not a Module subclass" ] } ], "source": [ "model = nn.Sequential(\n", " nn.Conv2d(3, 16, kernel_size=3, padding=1),\n", " nn.Tanh(),\n", " nn.MaxPool2d(2),\n", " nn.Conv2d(16, 8, kernel_size=3, padding=1),\n", " nn.Tanh(),\n", " nn.MaxPool2d(2),\n", " ...)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = nn.Sequential(\n", " nn.Conv2d(3, 16, kernel_size=3, padding=1),\n", " nn.Tanh(),\n", " nn.MaxPool2d(2),\n", " nn.Conv2d(16, 8, kernel_size=3, padding=1),\n", " nn.Tanh(),\n", " nn.MaxPool2d(2),\n", " # WARNING: something missing here\n", " nn.Linear(512, 32),\n", " nn.Tanh(),\n", " nn.Linear(32, 2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sum([p.numel() for p in model.parameters()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model(img.unsqueeze(0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Net(nn.Module):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n", " self.act1 = nn.Tanh()\n", " self.pool1 = nn.MaxPool2d(2)\n", " self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n", " self.act2 = nn.Tanh()\n", " self.pool2 = nn.MaxPool2d(2)\n", " self.fc1 = nn.Linear(8 * 8 * 8, 32)\n", " self.act4 = nn.Tanh()\n", " self.fc2 = nn.Linear(32, 2)\n", "\n", " def forward(self, x):\n", " out = self.pool1(self.act1(self.conv1(x)))\n", " out = self.pool2(self.act2(self.conv2(out)))\n", " out = out.view(-1, 8 * 8 * 8)\n", " out = self.act4(self.fc1(out))\n", " out = self.fc2(out)\n", " return out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Net()\n", "\n", "sum([p.numel() for p in model.parameters()])" ] }, { "cell_type": "code", "execution_count": 
null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Small CNN: two conv/tanh/max-pool stages followed by two linear layers.\n",
    "\n",
    "    The hard-coded 8 * 8 * 8 flatten size corresponds to 3x32x32 inputs\n",
    "    (padding=1 keeps the spatial size, each 2x2 pool halves it).\n",
    "    forward() returns raw scores for 2 classes; no softmax is applied.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
    "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
    "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
    "        self.fc2 = nn.Linear(32, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)\n",
    "        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)\n",
    "        # Flatten to (batch, 8 * 8 * 8) so the result matches fc1's input width.\n",
    "        out = out.view(-1, 8 * 8 * 8)\n",
    "        out = torch.tanh(self.fc1(out))\n",
    "        out = self.fc2(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net()\n",
    "model(img.unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
    "                                           shuffle=True)\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Same layout as the previous Net, with ReLU after the convolutions.\n",
    "\n",
    "    The hidden linear layer still uses tanh; forward() returns raw scores\n",
    "    for 2 classes, which is what nn.CrossEntropyLoss below consumes.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
    "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
    "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
    "        self.fc2 = nn.Linear(32, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.max_pool2d(torch.relu(self.conv1(x)), 2)\n",
    "        out = F.max_pool2d(torch.relu(self.conv2(out)), 2)\n",
    "        # Flatten to (batch, 8 * 8 * 8) so the result matches fc1's input width.\n",
    "        out = out.view(-1, 8 * 8 * 8)\n",
    "        out = torch.tanh(self.fc1(out))\n",
    "        out = self.fc2(out)\n",
    "        return out\n",
    "\n",
    "model = Net()\n",
    "\n",
    "learning_rate = 1e-2\n",
    "\n",
    "optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "n_epochs = 100\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    for imgs, labels in train_loader:\n",
    "        outputs = model(imgs)\n",
    "        loss = loss_fn(outputs, labels)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # `loss` is the loss of the last mini-batch of the epoch, not an epoch average.\n",
    "    print(\"Epoch: %d, Loss: %f\" % (epoch, float(loss)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy on the training subset; this rebinds train_loader to an unshuffled loader.\n",
    "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
    "                                           shuffle=False)\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "with torch.no_grad():\n",
    "    for imgs, labels in train_loader:\n",
    "        outputs = model(imgs)\n",
    "        # Index of the larger of the two scores is the predicted class.\n",
    "        _, predicted = torch.max(outputs, dim=1)\n",
    "        total += labels.shape[0]\n",
    "        correct += int((predicted == labels).sum())\n",
    "\n",
    "print(\"Accuracy: %f\" % (correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same accuracy computation on the validation subset.\n",
    "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n",
    "                                         shuffle=False)\n",
    "\n",
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "with torch.no_grad():\n",
    "    for imgs, labels in val_loader:\n",
    "        outputs = model(imgs)\n",
    "        _, predicted = torch.max(outputs, dim=1)\n",
    "        total += labels.shape[0]\n",
    "        correct += int((predicted == labels).sum())\n",
    "\n",
    "print(\"Accuracy: %f\" % (correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"ReLU variant of the small CNN, redefined here to count its parameters.\"\"\"\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)\n",
    "        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)\n",
    "        self.fc1 = nn.Linear(8 * 8 * 8, 32)\n",
    "        self.fc2 = nn.Linear(32, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.max_pool2d(torch.relu(self.conv1(x)), 2)\n",
    "        out = F.max_pool2d(torch.relu(self.conv2(out)), 2)\n",
    "        out = out.view(-1, 8 * 8 * 8)\n",
    "        out = torch.tanh(self.fc1(out))\n",
    "        out = self.fc2(out)\n",
    "        return out\n",
    "\n",
    "# NOTE: this rebinds `model` to a freshly initialized (untrained) network.\n",
    "model = Net()\n",
    "sum([p.numel() for p in model.parameters()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nn.Sequential(\n",
    "    nn.Conv2d(3, 16, kernel_size=3, padding=1),\n",
    "    nn.Tanh(),\n",
    "    nn.MaxPool2d(2),\n",
    "    nn.Conv2d(16, 8, kernel_size=3, padding=1),\n",
    "    nn.Tanh(),\n",
    "    nn.MaxPool2d(2),\n",
    "    nn.Linear(8*8*8, 32),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(32, 2))\n",
    "\n",
    "# Nothing in this stack reshapes the 4-D pooling output before the first\n",
    "# nn.Linear (Net.forward does that with view(-1, 8 * 8 * 8)), so the forward\n",
    "# pass is expected to raise a shape-mismatch RuntimeError. Display the message\n",
    "# instead of leaving an unhandled exception in the notebook.\n",
    "try:\n",
    "    print(model(img.unsqueeze(0)))\n",
    "except RuntimeError as err:\n",
    "    print(err)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}