|
@@ -0,0 +1,372 @@
|
|
|
|
|
+{
|
|
|
|
|
+ "cells": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 1,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "data": {
|
|
|
|
|
+ "text/plain": [
|
|
|
|
|
+ "<torch._C.Generator at 0x1301103b0>"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ "execution_count": 1,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "output_type": "execute_result"
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "import torch\n",
|
|
|
|
|
+ "import torch.nn as nn\n",
|
|
|
|
|
+ "import torch.nn.functional as F\n",
|
|
|
|
|
+ "import torch.optim as optim\n",
|
|
|
|
|
+ "from torch.utils.data import DataLoader, random_split\n",
|
|
|
|
|
+ "from torchvision import datasets\n",
|
|
|
|
|
+ "import torchvision.transforms as transforms\n",
|
|
|
|
|
+ "import matplotlib.pyplot as plt\n",
|
|
|
|
|
+ "%matplotlib inline\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "torch.manual_seed(12046)"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 2,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "data": {
|
|
|
|
|
+ "text/plain": [
|
|
|
|
|
+ "(50000, 10000, 10000)"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ "execution_count": 2,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "output_type": "execute_result"
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "# 准备数据\n",
|
|
|
|
|
+ "dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())\n",
|
|
|
|
|
+ "# 将数据划分成训练集、验证集、测试集\n",
|
|
|
|
|
+ "train_set, val_set = random_split(dataset, [50000, 10000])\n",
|
|
|
|
|
+ "test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())\n",
|
|
|
|
|
+ "len(train_set), len(val_set), len(test_set)"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 3,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "# 构建数据读取器\n",
|
|
|
|
|
+ "train_loader = DataLoader(train_set, batch_size=500, shuffle=True)\n",
|
|
|
|
|
+ "val_loader = DataLoader(val_set, batch_size=500, shuffle=True)\n",
|
|
|
|
|
+ "test_loader = DataLoader(test_set, batch_size=500, shuffle=True)"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 4,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "torch.Size([1, 3, 28, 28]) torch.Size([1, 3, 28, 28])\n",
|
|
|
|
|
+ "tensor([7]) torch.Size([1, 4, 4, 4]) torch.Size([1, 4, 4, 4])\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "# 卷积层的几个小技巧\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "# 这个卷积操作,输入和输出的形状是一样的\n",
|
|
|
|
|
+ "conv3 = nn.Conv2d(3, 3, (3, 3), stride=1, padding=1)\n",
|
|
|
|
|
+ "x = torch.randn(1, 3, 28, 28)\n",
|
|
|
|
|
+ "print(x.size(), conv3(x).size())\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "# 这两个卷积操作输出的形状是一样的\n",
|
|
|
|
|
+ "stride = torch.randint(0, 10, (1,))\n",
|
|
|
|
|
+ "conv1 = nn.Conv2d(3, 4, (3, 3), stride=stride, padding=1)\n",
|
|
|
|
|
+ "conv2 = nn.Conv2d(3, 4, (1, 1), stride=stride, padding=0)\n",
|
|
|
|
|
+ "x = torch.randn(1, 3, 28, 28)\n",
|
|
|
|
|
+ "print(stride, conv1(x).size(), conv2(x).size())"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 5,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "torch.Size([1, 3, 28, 28])\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "ename": "RuntimeError",
|
|
|
|
|
+ "evalue": "The size of tensor a (14) must match the size of tensor b (28) at non-singleton dimension 3",
|
|
|
|
|
+ "output_type": "error",
|
|
|
|
|
+ "traceback": [
|
|
|
|
|
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
|
|
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
|
|
|
|
+ "\u001b[0;32m<ipython-input-5-4d7bea514a78>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;31m# 报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0mm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mResidualBlockBugVersion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m28\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
|
|
|
+ "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1499\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1502\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
|
|
|
+ "\u001b[0;32m<ipython-input-5-4d7bea514a78>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;31m## 如果stride != 1 or in_channel != out_channel,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;31m## 下面的计算会出错,因为out和inputs的形状不一样\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
|
|
|
+ "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (14) must match the size of tensor b (28) at non-singleton dimension 3"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "class ResidualBlockBugVersion(nn.Module):\n",
|
|
|
|
|
+ " \n",
|
|
|
|
|
+ " def __init__(self, in_channel, out_channel, stride=1):\n",
|
|
|
|
|
+ " super().__init__()\n",
|
|
|
|
|
+ " self.conv1 = nn.Conv2d(\n",
|
|
|
|
|
+ " in_channel, out_channel, (3, 3), \n",
|
|
|
|
|
+ " stride=stride, padding=1, bias=False)\n",
|
|
|
|
|
+ " self.bn1 = nn.BatchNorm2d(out_channel)\n",
|
|
|
|
|
+ " self.conv2 = nn.Conv2d(\n",
|
|
|
|
|
+ " out_channel, out_channel, (3, 3),\n",
|
|
|
|
|
+ " stride=1, padding=1, bias=False)\n",
|
|
|
|
|
+ " self.bn2 = nn.BatchNorm2d(out_channel)\n",
|
|
|
|
|
+ " \n",
|
|
|
|
|
+ " def forward(self, x):\n",
|
|
|
|
|
+ " inputs = x\n",
|
|
|
|
|
+ " out = F.relu(self.bn1(self.conv1(x)))\n",
|
|
|
|
|
+ " out = self.bn2(self.conv2(out))\n",
|
|
|
|
|
+ " # 残差连接\n",
|
|
|
|
|
+ " ## 如果stride != 1 or in_channel != out_channel,\n",
|
|
|
|
|
+ " ## 下面的计算会出错,因为out和inputs的形状不一样\n",
|
|
|
|
|
+ " out += inputs\n",
|
|
|
|
|
+ " out = F.relu(out)\n",
|
|
|
|
|
+ " return out\n",
|
|
|
|
|
+ " \n",
|
|
|
|
|
+ "m = ResidualBlockBugVersion(3, 3)\n",
|
|
|
|
|
+ "print(m(torch.randn(1, 3, 28, 28)).size())\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "# 报错\n",
|
|
|
|
|
+ "m = ResidualBlockBugVersion(3, 3, 2)\n",
|
|
|
|
|
+ "print(m(torch.randn(1, 3, 28, 28)).size())"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 6,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "torch.Size([1, 4, 14, 14])\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "class ResidualBlock(nn.Module):\n",
|
|
|
|
|
+ " \n",
|
|
|
|
|
+ " def __init__(self, in_channel, out_channel, stride=1):\n",
|
|
|
|
|
+ " super().__init__()\n",
|
|
|
|
|
+ " self.conv1 = nn.Conv2d(\n",
|
|
|
|
|
+ " in_channel, out_channel, (3, 3), \n",
|
|
|
|
|
+ " stride=stride, padding=1, bias=False)\n",
|
|
|
|
|
+ " self.bn1 = nn.BatchNorm2d(out_channel)\n",
|
|
|
|
|
+ " self.conv2 = nn.Conv2d(\n",
|
|
|
|
|
+ " out_channel, out_channel, (3, 3),\n",
|
|
|
|
|
+ " stride=1, padding=1, bias=False)\n",
|
|
|
|
|
+ " self.bn2 = nn.BatchNorm2d(out_channel)\n",
|
|
|
|
|
+ " # 让输入的形状和输出的形状一样\n",
|
|
|
|
|
+ " self.downsample = None\n",
|
|
|
|
|
+ " if stride != 1 or in_channel != out_channel:\n",
|
|
|
|
|
+ " self.downsample = nn.Sequential(\n",
|
|
|
|
|
+ " nn.Conv2d(in_channel, out_channel, (1, 1), stride=stride, bias=False),\n",
|
|
|
|
|
+ " nn.BatchNorm2d(out_channel))\n",
|
|
|
|
|
+ " \n",
|
|
|
|
|
+ " def forward(self, x):\n",
|
|
|
|
|
+ " inputs = x\n",
|
|
|
|
|
+ " out = F.relu(self.bn1(self.conv1(x)))\n",
|
|
|
|
|
+ " out = self.bn2(self.conv2(out))\n",
|
|
|
|
|
+ " # 让输入(inputs)的形状和输出(out)的形状一样\n",
|
|
|
|
|
+ " if self.downsample is not None:\n",
|
|
|
|
|
+ " inputs = self.downsample(inputs)\n",
|
|
|
|
|
+ " out += inputs\n",
|
|
|
|
|
+ " out = F.relu(out)\n",
|
|
|
|
|
+ " return out\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "m = ResidualBlock(3, 4, 2)\n",
|
|
|
|
|
+ "print(m(torch.randn(1, 3, 28, 28)).size())"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 7,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "class ResNet(nn.Module):\n",
|
|
|
|
|
+ " \n",
|
|
|
|
|
+ " def __init__(self):\n",
|
|
|
|
|
+ " super().__init__()\n",
|
|
|
|
|
+ " self.block1 = ResidualBlock(1, 20)\n",
|
|
|
|
|
+ " self.block2 = ResidualBlock(20, 40, stride=2)\n",
|
|
|
|
|
+ " self.block3 = ResidualBlock(40, 60, stride=2)\n",
|
|
|
|
|
+ " self.block4 = ResidualBlock(60, 80, stride=2)\n",
|
|
|
|
|
+ " self.block5 = ResidualBlock(80, 100, stride=2)\n",
|
|
|
|
|
+ " self.block6 = ResidualBlock(100, 120, stride=2)\n",
|
|
|
|
|
+ " self.lm = nn.Linear(120, 10)\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ " def forward(self, x):\n",
|
|
|
|
|
+ " x = self.block1(x) # (B, 20, 28, 28)\n",
|
|
|
|
|
+ " x = self.block2(x) # (B, 40, 14, 14)\n",
|
|
|
|
|
+ " x = self.block3(x) # (B, 60, 7, 7)\n",
|
|
|
|
|
+ " x = self.block4(x) # (B, 60, 4, 4)\n",
|
|
|
|
|
+ " x = self.block5(x) # (B, 60, 2, 2)\n",
|
|
|
|
|
+ " x = self.block6(x) # (B, 120, 1, 1)\n",
|
|
|
|
|
+ " out = self.lm(x.view(x.shape[0], -1))\n",
|
|
|
|
|
+ " return out\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "model = ResNet()"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 8,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "eval_iters = 10\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "def estimate_loss(model):\n",
|
|
|
|
|
+ " re = {}\n",
|
|
|
|
|
+ " # 将模型切换至评估模式\n",
|
|
|
|
|
+ " model.eval()\n",
|
|
|
|
|
+ " re['train'] = _loss(model, train_loader)\n",
|
|
|
|
|
+ " re['val'] = _loss(model, val_loader)\n",
|
|
|
|
|
+ " re['test'] = _loss(model, test_loader)\n",
|
|
|
|
|
+ " # 将模型切换至训练模式\n",
|
|
|
|
|
+ " model.train()\n",
|
|
|
|
|
+ " return re\n",
|
|
|
|
|
+ "\n",
|
|
|
|
|
+ "@torch.no_grad()\n",
|
|
|
|
|
+ "def _loss(model, data_loader):\n",
|
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
|
+ " 计算模型在不同数据集下面的评估指标\n",
|
|
|
|
|
+ " \"\"\"\n",
|
|
|
|
|
+ " loss = []\n",
|
|
|
|
|
+ " accuracy = []\n",
|
|
|
|
|
+ " data_iter = iter(data_loader)\n",
|
|
|
|
|
+ " for k in range(eval_iters):\n",
|
|
|
|
|
+ " inputs, labels = next(data_iter)\n",
|
|
|
|
|
+ " B = inputs.shape[0]\n",
|
|
|
|
|
+ " logits = model(inputs)\n",
|
|
|
|
|
+ " # 计算模型损失\n",
|
|
|
|
|
+ " loss.append(F.cross_entropy(logits, labels))\n",
|
|
|
|
|
+ " # 计算预测的准确率\n",
|
|
|
|
|
+ " _, predicted = torch.max(logits, 1)\n",
|
|
|
|
|
+ " accuracy.append((predicted == labels).sum() / B)\n",
|
|
|
|
|
+ " re = {\n",
|
|
|
|
|
+ " 'loss': torch.tensor(loss).mean().item(),\n",
|
|
|
|
|
+ " 'accuracy': torch.tensor(accuracy).mean().item()\n",
|
|
|
|
|
+ " }\n",
|
|
|
|
|
+ " return re"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 9,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "def train_resnet(model, optimizer, data_loader, epochs=10, penalty=[]):\n",
|
|
|
|
|
+ " lossi = []\n",
|
|
|
|
|
+ " for epoch in range(epochs):\n",
|
|
|
|
|
+ " for i, data in enumerate(data_loader, 0):\n",
|
|
|
|
|
+ " inputs, labels = data\n",
|
|
|
|
|
+ " optimizer.zero_grad()\n",
|
|
|
|
|
+ " logits = model(inputs)\n",
|
|
|
|
|
+ " loss = F.cross_entropy(logits, labels)\n",
|
|
|
|
|
+ " lossi.append(loss.item())\n",
|
|
|
|
|
+ " # 增加惩罚项\n",
|
|
|
|
|
+ " for p in penalty:\n",
|
|
|
|
|
+ " loss += p(model)\n",
|
|
|
|
|
+ " loss.backward()\n",
|
|
|
|
|
+ " optimizer.step()\n",
|
|
|
|
|
+ " # 评估模型,并输出结果\n",
|
|
|
|
|
+ " stats = estimate_loss(model)\n",
|
|
|
|
|
+ " train_loss = f'train loss {stats[\"train\"][\"loss\"]:.4f}'\n",
|
|
|
|
|
+ " val_loss = f'val loss {stats[\"val\"][\"loss\"]:.4f}'\n",
|
|
|
|
|
+ " test_loss = f'test loss {stats[\"test\"][\"loss\"]:.4f}'\n",
|
|
|
|
|
+ " print(f'epoch {epoch:>2}: {train_loss}, {val_loss}, {test_loss}')\n",
|
|
|
|
|
+ " train_acc = f'train acc {stats[\"train\"][\"accuracy\"]:.4f}'\n",
|
|
|
|
|
+ " val_acc = f'val acc {stats[\"val\"][\"accuracy\"]:.4f}'\n",
|
|
|
|
|
+ " test_acc = f'test acc {stats[\"test\"][\"accuracy\"]:.4f}'\n",
|
|
|
|
|
+ " print(f'{\"\":>10}{train_acc}, {val_acc}, {test_acc}')\n",
|
|
|
|
|
+ " return lossi"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 10,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "stats = {}"
|
|
|
|
|
+ ]
|
|
|
|
|
+ },
|
|
|
|
|
+ {
|
|
|
|
|
+ "cell_type": "code",
|
|
|
|
|
+ "execution_count": 11,
|
|
|
|
|
+ "metadata": {},
|
|
|
|
|
+ "outputs": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "name": "stdout",
|
|
|
|
|
+ "output_type": "stream",
|
|
|
|
|
+ "text": [
|
|
|
|
|
+ "epoch 0: train loss 0.0820, val loss 0.1000, test loss 0.0955\n",
|
|
|
|
|
+ " train acc 0.9728, val acc 0.9646, test acc 0.9686\n",
|
|
|
|
|
+ "epoch 1: train loss 0.0826, val loss 0.0873, test loss 0.0864\n",
|
|
|
|
|
+ " train acc 0.9746, val acc 0.9744, test acc 0.9734\n",
|
|
|
|
|
+ "epoch 2: train loss 0.0523, val loss 0.0689, test loss 0.0634\n",
|
|
|
|
|
+ " train acc 0.9860, val acc 0.9798, test acc 0.9804\n",
|
|
|
|
|
+ "epoch 3: train loss 0.0781, val loss 0.1083, test loss 0.0830\n",
|
|
|
|
|
+ " train acc 0.9762, val acc 0.9700, test acc 0.9754\n",
|
|
|
|
|
+ "epoch 4: train loss 0.0157, val loss 0.0412, test loss 0.0365\n",
|
|
|
|
|
+ " train acc 0.9930, val acc 0.9870, test acc 0.9894\n"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "source": [
|
|
|
|
|
+ "stats['resnet'] = train_resnet(model, optim.Adam(model.parameters(), lr=0.01), train_loader, epochs=5)"
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "metadata": {
|
|
|
|
|
+ "kernelspec": {
|
|
|
|
|
+ "display_name": "Python 3",
|
|
|
|
|
+ "language": "python",
|
|
|
|
|
+ "name": "python3"
|
|
|
|
|
+ },
|
|
|
|
|
+ "language_info": {
|
|
|
|
|
+ "codemirror_mode": {
|
|
|
|
|
+ "name": "ipython",
|
|
|
|
|
+ "version": 3
|
|
|
|
|
+ },
|
|
|
|
|
+ "file_extension": ".py",
|
|
|
|
|
+ "mimetype": "text/x-python",
|
|
|
|
|
+ "name": "python",
|
|
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
|
|
+ "version": "3.8.5"
|
|
|
|
|
+ }
|
|
|
|
|
+ },
|
|
|
|
|
+ "nbformat": 4,
|
|
|
|
|
+ "nbformat_minor": 4
|
|
|
|
|
+}
|