{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Pinned to torchvision 0.15.2, the release that pairs with torch 2.0.1.\n", "# %pip (rather than !pip) installs into the environment of the running kernel.\n", "%pip install torchvision==0.15.2" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, random_split\n", "from torchvision import datasets\n", "import torchvision.transforms as transforms\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "torch.manual_seed(12046)\n", "\n", "dataset = datasets.MNIST(root='./mnist', train=True, download=True, transform=transforms.ToTensor())" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 28, 28])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, y = dataset[21]\n", "x.shape" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOCElEQVR4nO3da6hd9ZnH8d9vYosQRdQcY0jFUy8YZHBsOYSBBuOFKcY3SYNIfdEkIqZIxAs1jmaEegERnbaMMhRPR0kcO5ZiNYpIp5lQL31TcpRMEm8TRyIaojkhStQX6XjyzIuzLEc9+7+Oe+3byfP9wGHvvZ699nqyk1/WPuu/1v47IgTg6Pc3/W4AQG8QdiAJwg4kQdiBJAg7kMQxvdzYvHnzYnh4uJebBFLZs2ePDhw44OlqjcJu+1JJ/yJpjqR/i4h7S88fHh7W2NhYk00CKBgZGWlZa/tjvO05kv5V0jJJ50q60va57b4egO5q8jv7YklvRcTbEfEXSb+RtLwzbQHotCZhXyjp3SmP36uWfYHttbbHbI+Nj4832ByAJrp+ND4iRiNiJCJGhoaGur05AC00CfteSadNefytahmAAdQk7NsknW3727a/KemHkp7pTFsAOq3tobeI+Mz2dZL+U5NDb49ExKsd6wxARzUaZ4+I5yQ916FeAHQRp8sCSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kERPp2wGpnr++eeL9YsvvrhYj4i2X3/p0qXFdY9G7NmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2dFVGzdubFl74IEHiuvOmTOnWJ+YmCjWb7rpppa11atXF9ddt25dsX7MMbMvOo06tr1H0seSJiR9FhEjnWgKQOd14r+niyLiQAdeB0AX8Ts7kETTsIekP9h+2fba6Z5ge63tMdtj4+PjDTcHoF1Nw74kIr4raZmkdbYv+PITImI0IkYiYmRoaKjh5gC0q1HYI2Jvdbtf0lOSFneiKQCd13bYbc+1ffzn9yV9X9KuTjUGoLOaHI2fL+kp25+/zn9ExO870hVmjdI4uiQ9+uijLWs7d+7scDczf/2bb765uO6KFSuK9dNPP72dlvqq7bBHxNuS/q6DvQDoIobegCQIO5AEYQeSIOxAEoQdSGL2XaeHr+Wjjz4q1rdv316sX3XVVcV63SnQhw8fLtZLFi1aVKzXXeK6e/futrd9NGLPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM5+FNi8eXPL2ujoaHHdLVu2FOt1Y9l1X/fcxPr164v1I0eOFOvXXHNNJ9uZ9dizA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLPPAo899lixvmrVqq5tOyKK9bpx+G5uu043e5uN2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMsw+AunH0G264oVgvXVN+7LHHFtc95ZRTivVPPvmkWD948GCxXlLX2/HHH1+sHzp0qFjv5rX2s1Htnt32I7b32941ZdlJtrfY3l3dntjdNgE0NZOP8RslXfqlZbdK2hoRZ0vaWj0GMMBqwx4RL0r68me15ZI2Vfc3SVrR2bYAdFq7B+jmR8S+6v77kua3eqLttbbHbI/VzQsGoHsaH42PyasVWl6xEBGjETESESNDQ0NNNwegTe2G/QPbCySput3fuZYAdEO7YX9G0urq/mpJT3emHQDdUjvObvtxSRdKmmf7PUk/lXSvpN/avlrSO5Ku6GaTs13pe92l+uvRm4wXL168uFjfunVrsb5x48Zivcl3s99zzz3F+sqVK4v1ut7wRbVhj4grW5Qu6XAvALqI02WBJAg7kARhB5Ig7EAShB1IgktcO6BuCOjGG29s9Pp1l4KWhtcefPDBRtuuc9555xXra9asaVm79tprG2378ssvL9ZL01Vv27at0bZnI/bsQBKEHUi
CsANJEHYgCcIOJEHYgSQIO5AE4+wdcNdddxXrn376aaPX37BhQ7F+2223NXr9kiVLlhTry5YtK9bnz2/5jWWNHXfcccV63fkJ2bBnB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkGGefoe3bt7es1U1rPDExUawfOXKknZZ64qyzzup3C22bnKxoenV/J0cj9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7JVdu3YV66Xpgz/88MPiuk2mXEZrdec3HD58uGUt499J7Z7d9iO299veNWXZHbb32t5e/VzW3TYBNDWTj/EbJV06zfJfRMT51c9znW0LQKfVhj0iXpR0sAe9AOiiJgforrO9o/qYf2KrJ9lea3vM9tj4+HiDzQFoot2w/1LSmZLOl7RP0s9aPTEiRiNiJCJGhoaG2twcgKbaCntEfBARExFxRNKvJLWeRhTAQGgr7LYXTHn4A0nlcSsAfVc7zm77cUkXSppn+z1JP5V0oe3zJYWkPZJ+3L0We+P6668v1t99990edYKZeuKJJ4r1jHOwl9SGPSKunGbxw13oBUAXcboskARhB5Ig7EAShB1IgrADSXCJaw/cd999/W5hVnrjjTeK9VtuuaXt1x4eHi7Wj8bpntmzA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLP3wMknn9zvFgZS3Tj68uXLi/UDBw4U6/Pnz29Zq7s8trTubMWeHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJy9EhHF+sTERNuvvWbNmmJ91apVbb92v9VNm1z6s23evLnRts8888xi/dlnn21ZO+eccxptezZizw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDOXrn99tuL9R07drSsHTp0qNG2L7roomLddrFeuu67bjy57jvt684/OHz4cLFemjZ57ty5xXU3bNhQrK9cubJYzziWXlK7Z7d9mu0/2n7N9qu2b6iWn2R7i+3d1e2J3W8XQLtm8jH+M0k/iYhzJf29pHW2z5V0q6StEXG2pK3VYwADqjbsEbEvIl6p7n8s6XVJCyUtl7SpetomSSu61COADvhaB+hsD0v6jqQ/S5ofEfuq0vuSpv3SLttrbY/ZHhsfH2/SK4AGZhx228dJ+p2kGyPiC0ekYvIozrRHciJiNCJGImJkaGioUbMA2jejsNv+hiaD/uuIeLJa/IHtBVV9gaT93WkRQCfUDr15ctznYUmvR8TPp5SekbRa0r3V7dNd6bBHLrnkkmL9ySefbFmrGwKqG5p74YUXivU5c+YU6y+99FKx3kTdpb11vV1wwQUta6tXry6uO5sv/R1EMxln/56kH0naaXt7tWyDJkP+W9tXS3pH0hVd6RBAR9SGPSL+JKnVWR3l3SGAgcHpskAShB1IgrADSRB2IAnCDiTBJa4ztHTp0pa10uWvkjQ6Olqs33333W311AunnnpqsV4aR5ekhx56qGXthBNOaKsntIc9O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwTh7ByxcuLBYv/POO4v1M844o1i///77i/U333yzZW3RokXFddevX1+s1/W2ZMmSYh2Dgz07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOPsAqPv+9Lo6MBPs2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgidqw2z7N9h9tv2b7Vds3VMvvsL3X9vbq57LutwugXTM5qeYzST+JiFdsHy/pZdtbqtovIuKfu9cegE6Zyfzs+yTtq+5/bPt1SeWvZgEwcL7W7+y2hyV9R9Kfq0XX2d5h+xHbJ7ZYZ63tMdtj4+PjzboF0LYZh932cZJ+J+nGiDgk6ZeSzpR0vib3/D+bbr2IGI2IkYgYGRoaat4xgLbMKOy2v6HJoP86Ip6UpIj4ICImIuKIpF9JWty9NgE0NZOj8Zb0sKTXI+LnU5YvmPK0H0ja1fn2AHTKTI7Gf0/SjyTttL29WrZB0pW2z5cUkvZI+nEX+gPQITM
5Gv8nSZ6m9Fzn2wHQLZxBByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSMIR0buN2eOS3pmyaJ6kAz1r4OsZ1N4GtS+J3trVyd5Oj4hpv/+tp2H/ysbtsYgY6VsDBYPa26D2JdFbu3rVGx/jgSQIO5BEv8M+2uftlwxqb4Pal0Rv7epJb339nR1A7/R7zw6gRwg7kERfwm77Uttv2n7L9q396KEV23ts76ymoR7rcy+P2N5ve9eUZSfZ3mJ7d3U77Rx7feptIKbxLkwz3tf3rt/Tn/f8d3bbcyT9j6R/kPSepG2SroyI13raSAu290gaiYi+n4Bh+wJJn0h6NCL+tlp2n6SDEXFv9R/liRHxjwPS2x2SPun3NN7VbEULpk4zLmmFpDXq43tX6OsK9eB968eefbGktyLi7Yj4i6TfSFrehz4GXkS8KOnglxYvl7Spur9Jk/9Yeq5FbwMhIvZFxCvV/Y8lfT7NeF/fu0JfPdGPsC+U9O6Ux+9psOZ7D0l/sP2y7bX9bmYa8yNiX3X/fUnz+9nMNGqn8e6lL00zPjDvXTvTnzfFAbqvWhIR35W0TNK66uPqQIrJ38EGaex0RtN498o004z/VT/fu3anP2+qH2HfK+m0KY+/VS0bCBGxt7rdL+kpDd5U1B98PoNudbu/z/381SBN4z3dNOMagPeun9Of9yPs2ySdbfvbtr8p6YeSnulDH19he2514ES250r6vgZvKupnJK2u7q+W9HQfe/mCQZnGu9U04+rze9f36c8jouc/ki7T5BH5/5X0T/3ooUVfZ0j67+rn1X73JulxTX6s+z9NHtu4WtLJkrZK2i3pvySdNEC9/buknZJ2aDJYC/rU2xJNfkTfIWl79XNZv9+7Ql89ed84XRZIggN0QBKEHUiCsANJEHYgCcIOJEHYgSQIO5DE/wOeBksu2CwdCAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(x.squeeze(0).numpy(), cmap=plt.cm.binary)" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "train_set, val_set = random_split(dataset, [50000, 10000])\n", "test_set = datasets.MNIST(root='./mnist', train=False, download=True, transform=transforms.ToTensor())" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# val/test are shuffled as well: _loss only scores the first eval_iters batches,\n", "# so shuffling makes each evaluation a fresh random sample of the split.\n", "train_loader = DataLoader(train_set, batch_size=500, shuffle=True)\n", "val_loader = DataLoader(val_set, batch_size=500, shuffle=True)\n", "test_loader = DataLoader(test_set, batch_size=500, shuffle=True)" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([500, 1, 28, 28]), torch.Size([500]), torch.Size([500, 784]))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, y = next(iter(train_loader))\n", "x.shape, y.shape, x.view(x.shape[0], -1).shape" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", "    \"\"\"784-30-20-10 fully connected classifier with sigmoid hidden activations.\n", "\n", "    Same architecture as the sigmoid nn.Sequential model trained further below.\n", "    \"\"\"\n", "\n", "    def __init__(self):\n", "        # nn.Module.__init__ must run before any submodule is assigned.\n", "        super().__init__()\n", "        self.net = nn.Sequential(\n", "            nn.Linear(784, 30), nn.Sigmoid(),\n", "            nn.Linear( 30, 20), nn.Sigmoid(),\n", "            nn.Linear( 20, 10)\n", "        )\n", "\n", "    def forward(self, x):\n", "        # Accepts flattened (B, 784) or image-shaped (B, 1, 28, 28) input.\n", "        return self.net(x.flatten(1))\n", "\n", "\n", "model = MLP()" ] },
{ "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "eval_iters = 10\n", "\n", "\n", "def estimate_loss(model):\n", "    \"\"\"Score `model` on the train/val/test loaders.\n", "\n", "    Returns {'train': stats, 'val': stats, 'test': stats}, where each stats dict\n", "    holds the mean 'loss' and 'acc' over at most `eval_iters` batches.\n", "    The model's train/eval mode is restored afterwards, even if scoring fails.\n", "    \"\"\"\n", "    was_training = model.training\n", "    # Switch to evaluation mode (affects layers such as dropout / batch norm).\n", "    model.eval()\n", "    try:\n", "        results = {\n", "            'train': _loss(model, train_loader),\n", "            'val': _loss(model, val_loader),\n", "            'test': _loss(model, test_loader),\n", "        }\n", "    finally:\n", "        # Restore the caller's mode instead of unconditionally forcing train mode.\n", "        model.train(was_training)\n", "    return results\n", "\n", "\n", "@torch.no_grad()\n", "def _loss(model, dataloader):\n", "    \"\"\"Mean cross-entropy loss and accuracy of `model` on up to `eval_iters` batches.\n", "\n", "    Stops early if `dataloader` yields fewer batches (instead of raising\n", "    StopIteration) and returns NaN for both metrics if it yields none.\n", "    Every batch is weighted equally in the averages.\n", "    \"\"\"\n", "    losses = []\n", "    accs = []\n", "    # zip exhausts range() first, so no extra batch is pulled from the loader.\n", "    for _, (inputs, labels) in zip(range(eval_iters), dataloader):\n", "        # inputs: (B, 1, 28, 28), labels: (B,)\n", "        logits = model(inputs.flatten(1))\n", "        losses.append(F.cross_entropy(logits, labels))\n", "        # argmax over logits == argmax over softmax(logits): softmax is monotonic.\n", "        preds = torch.argmax(logits, dim=-1)\n", "        accs.append((preds == labels).float().mean())\n", "    if not losses:\n", "        return {'loss': float('nan'), 'acc': float('nan')}\n", "    return {\n", "        'loss': torch.stack(losses).mean().item(),\n", "        'acc': torch.stack(accs).mean().item(),\n", "    }" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([3, 3, 3, 2, 3, 1]), tensor([1, 1, 0, 2, 3, 3]))" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = torch.randint(4, (6, ))\n", "labels = torch.randint(4, (6, ))\n", "preds, labels" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(preds == labels).sum()" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': {'loss': 2.346519947052002, 'acc': 0.1096000075340271},\n", " 'val': {'loss': 2.345228672027588, 'acc': 0.10600000619888306},\n", " 'test': {'loss': 2.342650890350342, 'acc': 0.11500000953674316}}" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "estimate_loss(model)" ] },
{ "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "def train_model(model, optimizer, epochs=10, penalty=False, l1=0.001, l2=0.002):\n", "    \"\"\"Train `model` on `train_loader` with cross-entropy, reporting losses each epoch.\n", "\n", "    Args:\n", "        model: classifier fed flattened (B, 784) MNIST batches.\n", "        optimizer: optimizer over the model's parameters.\n", "        epochs: number of passes over `train_loader`.\n", "        penalty: if True, add the elastic-net term\n", "            l1 * sum(|w|) + l2 * sum(w ** 2) over all parameters.\n", "        l1, l2: penalty coefficients (default 0.001 and 0.002).\n", "\n", "    Returns:\n", "        List of per-batch cross-entropy values (penalty term excluded).\n", "    \"\"\"\n", "    lossi = []\n", "    for e in range(epochs):\n", "        for inputs, labels in train_loader:\n", "            logits = model(inputs.flatten(1))\n", "            loss = F.cross_entropy(logits, labels)\n", "            # Log the data loss before any penalty is added.\n", "            lossi.append(loss.item())\n", "            if penalty:\n", "                w = torch.cat([p.reshape(-1) for p in model.parameters()])\n", "                # Out-of-place add keeps the autograd graph free of in-place edits.\n", "                loss = loss + (l1 * w.abs().sum() + l2 * w.square().sum())\n", "            optimizer.zero_grad()\n", "            loss.backward()\n", "            optimizer.step()\n", "        stats = estimate_loss(model)\n", "        train_loss = f'{stats[\"train\"][\"loss\"]:.3f}'\n", "        val_loss = f'{stats[\"val\"][\"loss\"]:.3f}'\n", "        test_loss = f'{stats[\"test\"][\"loss\"]:.3f}'\n", "        print(f'epoch {e} train {train_loss} val {val_loss} test {test_loss}')\n", "    return lossi" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "# Per-batch training-loss histories returned by train_model, keyed by experiment name.\n", "loss = {}" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 2.317 val 2.316 test 2.313\n", "epoch 1 train 2.306 val 2.304 test 2.304\n", "epoch 2 train 2.301 val 2.303 test 2.302\n", "epoch 3 train 2.300 val 2.300 test 2.301\n", "epoch 4 train 2.299 val 2.301 test 2.301\n", "epoch 5 train 2.299 val 2.300 test 2.299\n", "epoch 6 train 2.298 val 2.300 test 2.298\n", "epoch 7 train 2.298 val 2.299 test 2.298\n", "epoch 8 train 2.298 val 2.299 test 2.299\n", "epoch 9 train 2.296 val 2.297 test 2.297\n" ] } ], "source": [ "model = nn.Sequential(\n", "    nn.Linear(784, 30), nn.Sigmoid(),\n", "    nn.Linear( 30, 20), nn.Sigmoid(),\n", "    nn.Linear( 20, 10)\n", ")\n", "\n", "loss['mlp'] = train_model(model, optim.SGD(model.parameters(), lr=0.01))" ] },
{ "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 2.286 val 2.287 test 2.287\n", "epoch 1 train 2.249 val 2.248 test 2.246\n", "epoch 2 train 2.191 val 2.189 test 2.187\n", "epoch 3 train 2.110 val 2.106 test 2.104\n", "epoch 4 train 1.976 val 1.974 test 1.971\n", "epoch 5 train 1.786 val 1.781 test 1.783\n", "epoch 6 train 1.539 val 1.526 test 1.524\n", "epoch 7 train 1.277 val 1.278 test 1.271\n", "epoch 8 train 1.072 val 1.072 test 1.076\n", "epoch 9 train 0.955 val 0.936 test 0.947\n" ] } ], "source": [ "# Same layout as the sigmoid model above, with ReLU hidden activations.\n", "model = nn.Sequential(\n", "    nn.Linear(784, 30), nn.ReLU(),\n", "    nn.Linear( 30, 20), nn.ReLU(),\n", "    nn.Linear( 20, 10)\n", ")\n", "\n", "loss['mlp_relu'] = train_model(model, optim.SGD(model.parameters(), lr=0.01))" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 1.449 val 1.457 test 1.450\n", "epoch 1 train 1.072 val 1.074 test 1.054\n", "epoch 2 train 0.801 val 0.803 test 0.790\n", "epoch 3 train 0.617 val 0.637 test 0.615\n", "epoch 4 train 0.511 val 0.514 test 0.502\n", "epoch 5 train 0.420 val 0.435 test 0.430\n", "epoch 6 train 0.370 val 0.397 test 0.369\n", "epoch 7 train 0.332 val 0.340 test 0.332\n", "epoch 8 train 0.317 val 0.319 test 0.308\n", "epoch 9 train 0.279 val 0.298 test 0.280\n" ] } ], "source": [ "# ReLU model with a LayerNorm after each hidden Linear (those Linears use bias=False).\n", "model = nn.Sequential(\n", "    nn.Linear(784, 30, bias=False), nn.LayerNorm(30), nn.ReLU(),\n", "    nn.Linear( 30, 20, bias=False), nn.LayerNorm(20), nn.ReLU(),\n", "    nn.Linear( 20, 10)\n", ")\n", "\n", "loss['mlp_relu_layer'] = train_model(model, optim.SGD(model.parameters(), lr=0.01))" ] },
{ "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAAsTAAALEwEAmpwYAAA6MUlEQVR4nO3dd3hUVfrA8e+Zkkw6pAAhARKadEOHxUIVRKqiqyKCuiKuLPJT17UXVrFhF7EgBhQsqKyIBZQiIgKGKr2XUEIIENIz5fz+uEMIAtKS3MzM+3meeZK5czPz3rmTd84999z3KK01QgghfJ/F7ACEEEKUDUnoQgjhJyShCyGEn5CELoQQfkISuhBC+AmbWS8cGxurk5KSzHp5IYTwScuXLz+ktY473WOmJfSkpCTS0tLMenkhhPBJSqldZ3pMulyEEMJPSEIXQgg/IQldCCH8hCR0IYTwE5LQhRDCT0hCF0IIPyEJXQgh/IRp49Av1KYDOXy7Zh9WiwWbVeGwW2mRGEWLxCiCbVazwxNCCNP4XELfejCXN+ZtPWV5kM1C84QoQoNOJPWwIBsRDhtRIXY6X1KNTvVjUEpVZLhCCFFhlFkTXLRp00ZfzJWibo/G5fGQU+hixa4jLNtxmDXp2RS7PRzP2flFbo4VOjmcV0yRy0OjGhHc1imJxKqhZBwr5GBOEbmFLordHoqcbuIigkmpVZUWtaKIdNjLaEuFEKLsKKWWa63bnPYxn0vo6cthyXgIrw7h1SC8hvdndeMWUhWsJx94FDrdzFy9j0mLdrDxQM5JjykFwTYLdquFnEJXybKYsGAcdgsOuxWbRZW07IOsiiqhQVQNtRMdFkz1yGBqRDkID7ZxKLeIg8eKOFrgxG614LBbCLFbiQ4LIjY8mCqhdixK4dYarcFhtxAebCM0yIZS4PFo3FpjVQqb1YLVogi2WQi2WeTIQggB/HVC97kuF/KzYN9KyMkAZ97p1wmOgpAoCAoHmwOHPZQbwmK5/pIa7K1flbzwJEJqtSA6oT5hwfaSZJmd72R1+lFW7TnK/uxCipxuCl1unO4TX3pFLg9H8ovZcSiPrNwi8ordp7x8aJAVl1tT7PaUySYf/9IJDbIRGmQlNMhKiN1KkM1CkM2CxwO5RS5yi1w43R7vl4CVYLsFu/dcg81qwarAalFYlEIDxne5xuk2jnZcbo1SYFHGOg67hZAgGyF2C5bTfKEoBTbv8wdZLbg8mgKnm0LvexJst+KwW7BZFB4NHu8XmREDJ31Jae9jGlBAsN2Cw2bFbrPgcnsodmtcbg8eDRpjRbvV+LILsllKnt/jOXMDRSmwWixYLeB0a/KLXeQXu7EoRaTDToTDRpD39Vwejdv7BXu8zWO1KGyl3r/T0VqXbKfdasFusxBkVd7HjHXsVgvBdgtBVgsFTjdH850cK3B63zNj39msyrsfTrxPCuP98Xg0xzfTZlFGXFZVst+VgkKnh0KnmyKXG482jmg1YFUKq8V4TrfHeE+Pf76Pv5Y6/tP7nh1nsyjjM2e1YrV41+PEuhaLsY3auy+Ov2cW74paazyeE/vYYgG3B/KKXOQUOilyeQj1dpMaNzuRDhshQVZyCl0czXeSXeDEZlEl79PxfWK1KO86xRzJN9aJDLER6bBjsShcbo3T48Ht9u4fwOn2UFDsptBl/J9GOoz1HXYrHm3sf4tShHvjCQ2yUuzyUOTyUOzynPS/UrL/0cZred/XAqeb/CLjc+awW4kKsRMZYqNmlRBiw4PP+Fm9UL7XQi+tKBdyMyD34ImfBUdO3Jx54CyA4nzIy4ScA1BcqoUeFAGRNSGkCjiqGK37kKoQGg1BYWCxg8Vq/B4aYywPjYWwOAgKBSCn0EnGsSJyCp3EhgcTFxGMw27043u8Ce5wXjGZuUUczS8GTvyDFjnd5BW5yS92oTE+HFbLiX80l0dT5DK6gwq8t/wiN3nFLgqdxoeq2O3BoiDCYSc82IbNoijydiEVuTw43R7
vh9lIdm6PLvlnO/6PaLeeOCIA4x/P7dEUOj3G6xa7jSRayvHk6/ZonN44jKMSI4mDkVSKnG5cHl2SAI7/rceb+IwI8MZzImkVuTy4/5ScjyfT48/j9Cb4C2W1KELtVtxak3+aL2YhystdV9bl4asbX9Df+lcLvbTgcOMWU+/c/6YwGw5tgYy1kLHOSPKF2ZCzHzI3QP6Rk5P+mQRFQFgsEaHRRIRUhbBqENsA4hoZP8PisDiiCAu2ERZso1Z06IVvZ4A63nq0WdVJ3V5/XqfI5TFaShbjS/FMnVPHW/EujzZaeaW6slxuD7lFLopdHmxW46jD6v2CNVrkRuvS5TG+aBTeb5/SNCUxGM+pKXKfOMI7vrrLrSlyGV+4IUFGqy0qxI7C+CIrdLq9X7x4jxBOHCWA8UVktSi093GPxzgaPN56dHt0SXdfsN2K1dvSN1rJ4PJ+udusCrvVOIKCE1+0x4/eSjf2NMbfGa9hxHf8Ye09avJob8u71Jeu8cVt/LSWas1qjL9XCsKDjdZ4sN3iba27OFbgJMf7e0GxiwiHnSqhxvvk8UCx202h8/iRlNFoCXfYqBoaRNXQIFweD8cKXGQXONFal+zT0o0Cu9V4jxx2K1prjhUaRwoFTnfJvnd7NLnH43C6S46C7NbjR7kat+fUIxm79/WMI2obIXYrhS432flOjhW6qBUdcqaP/UXxuRZ62oE0Plj7Af/t9F9iQ2LLITLAVQzOfPC4jFtxHuQfNrp78g8ZRwJ5mcYt/7BxNJBzAHL2nfw8FjuExZ7c8o9tCNWbQvVmRp+/I+rkT4MQQvwFv2qhF7oLWbR3EbuO7Sq/hG4LMm6lnctRwPHWf9ZWyDtkJP+8Q1B41Gj5Z26Cjd+BLnV4r6xGoq/dAS65Bhr2NO4LIcR58rmEXjuiNgB7cvbQunprk6P5E0cUJLYxbmfiKoLMjXBwo9HCLzhstO63zYMN34CyQEQ8hEQbiT2mPsRfCjVToGoSBEdKi14IcVo+l9Djw+OxKiu7j+02O5QLYws2EnT8pScv93hg/0rYPBuy043unbxD8Md0SPvgxHrKapzEjUqEak2gWmOj+yahtbFcCBGwfC6h2y124sPi2ZOzx+xQypbFYiTlhD8ddXg8cGQH7F8Nx/Z5R/AchiO7YPsCWP3JiXVjGkB8C6hSB6rUhqhaEB5nnLANiztlfL4Qwr/45H947cja/pfQz8RiMfrvz9SHn38YDqyB9N+Ni67S02D918bJ3NJsIVCvKzTqDfV7GBdjSdeNEH7FJxN6rYha/JH5B1pruYIyNBrqdjZux7ldxjDMY3tPjMg5uB42fQ+bvjXWsQYbo2wia0KtdpB0uXFi1hFpxlYIIcqAzyb0HGcO2UXZVHFUMTucysdqgyq1jFtpvccZXTe7FhsJP/eg0Z2zZAIsfsNYJzLROBqIbQC1O0LdLhAWU/HbIIQ4bz6Z0EuPdJGEfh6UMkbL1Ew5eXlxvtFls2eZMeQyayus+Rx+nwgoo18+PsW4aCruEiPRB8mFUkJUNr6Z0CONhL47ZzfN45qbHI0fCAqFulcat+M8bti3yhhOueNn2DATVkw2HguOgpSboPVtUK2RKSELIU7lkwk9MSIRhWJ3jo8OXfQFFisktjZuV/7buL477xAcWA2rP4W0SbD0HaMGTlisUesmsS2k3GxcCSuEqHA+mdCDrcFUC63GnmMBMtKlMlDKGAJZv7tx6/W8MUb+8HYj0ecehKXvwm9vGd0z9boaJ1wja0KN5sYwSiFEufLJhA4BNnSxMgqLhQ53n7wsL8tI8qunwa+vn1zioFZ7aH49NLrGSPJCiDLnuwk9ojbz98w3OwxRWlgMdBhh3DxuY7jksb2w/WdY+yV894Bxi6gJNVsaQy1bDzWunhVCXDSfTeiJEYkcLjxMnjOPMHuY2eGIP7NYIaKGcUtoDZffBxnrjROse1fA3uXGmPjf3oIeY6BJf7nQSYiL5LMJvfTQxUbRMtLCJ1R
vYtyO2zYPZj8G04caRchCosEaZHTntL3DuNhJkrwQ58xidgAXqmTooq8W6RLGidMRv0DfN4yEHuQ90tr1K0zuCxO7w4ZZRveNEOKsfLaFXivCuApShi76OIvV6EdvPfTEMmchrJpqnFj9bLBx9WqrW6HlLRCVYF6sQlRyPpvQw+xhRDuiZaSLP7I7jC6XVkONfvblqbBgrHGLrmucUE1oDS3+bnTPCCEAH07oYPSjS0L3Y1abcbK0SX84vAPWzTBOpu5eaoyamftfaHMb/O1fMhRSCHw9oUfWZun+pWaHISpCdLIxUua4zM2w6FXjYqbfJ0KdTkaffL0uxoQfcjJVBCCfPSkKRj96Rn4Gha5Cs0MRFS2uIQycAKNWQLvhRvXIHx+Hdy6D15rDDw8bVSU9HrMjFaLC+HxCB0jPSTc5EmGaqknQ81m4ZynctwH6vWW00H//AD68Gt69Arb8ZNSiEcLP+XSXS/0q9QHYdGQT9avWNzkaYbrImtBqiHEryoH1M+HnF2DqdcaY9pZDjHHwMQ2ME69C+JmzttCVUrWUUvOVUuuVUuuUUveeZh2llHpDKbVVKbVGKdWqfMI9Wf0q9Qm1hbLy4MqKeDnhS4IjoOVgGJkGV78EBzfAjOFGl8zYmvDxdcaUfUL4kXNpobuA+7XWK5RSEcBypdSPWuv1pda5GmjgvbUHJnh/liurxUqLuBaszlxd3i8lfJUtCNoPN0bDZG01puLbvwZWTIGJXaHh1XDFA8YwSDmRKnzcWVvoWuv9WusV3t9zgA3An6/u6A9M0YYlQBWlVHyZR3saKdVS2HxkM3nOvIp4OeGrrHao1hiaXQc9nobRa6DrY7B7MUzsBhP+Br+9bUy6LYSPOq+TokqpJKAl8OexgglA6QHh6Zya9FFKDVdKpSml0jIzM88z1NNrGdcSj/awJnNNmTyfCBDBEXDFv2H0WujzGthDYPbD8FoLWPiSMS2fED7mnBO6Uioc+BIYrbU+diEvprV+T2vdRmvdJi4u7kKe4hTN45qjUKzKXFUmzycCjCPS6I65cx6M+NWYhm/eM/BmK2NmJhkdI3zIOSV0pZQdI5lP1Vp/dZpV9gKlp5hP9C4rdxFBEdSvWp/VB6UfXVykGs3gxqlw2w8QmQAz7oLPbjEm7hDCB5zLKBcFfABs0Fq/cobVZgK3eke7dACytdb7yzDOv9QyriWrM1fjlqp8oizU6Qh3/Ag9/gubZxv96+tmSP+6qPTOpYXeCRgCdFVKrfLeeiulRiilRnjX+Q7YDmwF3gf+WT7hnl5KtRRynblsy95WkS8r/JnFAp1GGV0xjiiYPgxeTIZXmsJXd0lyF5XSWYctaq0XAX85nktrrYF7yiqo85USlwLAqoOraFi1oVlhCH8U38Ko2b7rVzjwhzHkcd1XkP473Pw5xMoFbaLy8OlL/49LjEgkxhHDqoOrzA5F+CNbsFH4q9O9MOgDGPoNFGYbwx13LDQ7OiFK+EVCV0qRUi1FRrqIilG7A9w515gvdcoA+PlFcLvMjkoI/0joYHS77MnZw6GCQ2aHIgJB1SS4Yw40uxbmP2sUAju83eyoRIDzm4TesnpLAKmPLiqOIwqumwjXfQCZm2BCJ/j+ITiy0+zIRIDym4TePLY58WHxfL31a7NDEYGm+SD452Jo3Bd+fx/eaGmMisnJMDsyEWD8JqFblIUB9QewZP8S9uXuMzscEWiiEuHa9+DeNcaUeJtnw8TuRpVHISqI3yR0gAH1BwDwv63/MzUOEcCiEqDHGLjtO3AXwQc9YfvPZkclAoRfJfSa4TXpWLMjM7bOkKtGhblqtoR/zDUm3fj4WvjmXshYf/a/E+Ii+FVCBxjYYCAH8g6wZP8Ss0MRga5KLbhjNrS8xSj0NaEjpPaBrXOl6JcoF36X0LvW6kqV4Cp8teV0NcSEqGCOKOj7ujHfafenIGub0WKf1BO2zZfELsqU3yX0IGsQfer2Yd6eeRwpPGJ2OEIYQqPhsv+
De1dB73FwdA98NAC+/IfUXhdlxu8SOsC1Da7F5XFJK11UPrZgaHcnjFoJXR6FtV/CBz3g8A6zIxN+wC8TeoOqDegQ34GpG6ZS7C42OxwhTmV3wJUPwuAvIHsPvNcZts0zOyrh4/wyoQPc3ux2Mgsy+WbbN2aHIsSZNegOwxd4R8MMgqXvSb+6uGB+m9A7xHegcXRjUtel4tEes8MR4syi6xp1YRpcBd//G769D9xOs6MSPshvE7pSitub3c7OYzuZv3u+2eEI8deCI4zp7y77P0ibBJN6QeZms6MSPsZvEzpA9zrdSQxPZNLaSWg5jBWVncVqDG0c9CEc3gbvXAaLXpPSvOKc+XVCt1lsDGs6jDWH1pCWkWZ2OEKcm2bXwj3LoEEP+OlJSO1tDHMU4iz8OqED9K/fn9iQWMavGi+tdOE7wqvB3z+GaycaJQPevRw2fW92VKKS8/uE7rA5GN5iOMszlvPbvt/MDkeIc6cUtLge7voZomrBJzfCT0/LKBhxRn6f0AEGNRhEQngCr698XVrpwvfE1IM7foRWQ2HRK/Dt/eCRkVviVAGR0O1WO3dfejfrs9Yzd/dcs8MR4vzZHUZNmE6jIe0DmDkSpKKo+JOASOgAfer2oW5UXd5c+aaU1hW+SSljFEznR2DVVJj2d9i/2uyoRCUSMAndarEysuVItmdvZ+a2mWaHI8SFUQo6/weufhF2L4F3r4ApA2DXYrMjE5VAwCR0gO61u9M8tjlvrHyDPGee2eEIceHa3wX3rYPuTxvT3KX2Maa9EwEtoBK6UoqH2z3MoYJDvLvmXbPDEeLiOKLgstHwrzSo0Rw+Hwp7lpkdlTBRQCV0gOZxzRlQfwAfrf+Indk7zQ5HiIsXHGFUbYysCVOvh4MbzY5ImCTgEjrAva3uxWF18MLvL8gwRuEfwuNgyFdGvfXJfWHN5zK0MQAFZEKPDYnl7kvvZtHeRfycLjOyCz9RNQlu/dpoqX91pzFxRrqUvAgkAZnQAW5qfBP1ourx7NJnyS3ONTscIcpGtcZw53wYMAGy042kvnKq2VGJChKwCd1usfN0p6c5mH+QV5e/anY4QpQdiwVSbjZOliZfCV//05g4Q/i9gE3oAJfGXcqQxkP4fPPnLN2/1OxwhChbwRFw82fQqI8xccbCcVIHxs8FdEIHGNlyJHUi6/Dk4ifJd8rs68LP2ILh+lRofgPM+y98OhhyM82OSpSTgE/oDpuDMX8bw77cfbyx8g2zwxGi7FntMPBduOpZ2PojTOgopXj9VMAndIBW1Vvx90v+zrQN01h7aK3Z4QhR9iwW+NtIGP4zhNcwSvHOf066YPyMJHSvUa1GERcSx5jfxuDyyJRfwk9VbwJ3zoOUwfDz8/DNKJnizo9IQveKCIrgofYPseHwBqZtmGZ2OEKUH1sQ9B8Plz8AK6bAZ4PBWWB2VKIMSEIvpXvt7lyReAVvrXqL/bn7zQ5HiPKjFHR7HHqPg80/wOxHzY5IlAFJ6KUopXik/SMAjF06VsoCCP/X7k7oONKYNENOlPq8syZ0pdQkpdRBpdRpzxYqpTorpbKVUqu8tyfKPsyKkxCewD0p97AgfQHf75APuAgA3Z4wqjV+fQ/kZJgdjbgI59JCTwV6nWWdX7TWKd7bmIsPy1y3NL6FFrEteG7ZcxwqOGR2OEKUL1swXPcBFOfD/+6Wqe182FkTutZ6IXC4AmKpNKwWK2M6jSHPmcdzS58zOxwhyl/cJdDzWdg2F56Nh7c7wmdDYP1MGdroQ8qqD72jUmq1Uup7pVTTM62klBqulEpTSqVlZlbuq9XqVanH3ZfezZxdc/hx149mhyNE+Wtzu3FVafvhRuXGvSvg8yHGNHebvpfE7gPUuZz4U0olAbO01s1O81gk4NFa5yqlegOva60bnO0527Rpo9PSKndpT6fHyeBvB5NZkMnMATOJCIowOyQhKo7bBWu/gAXPw5Ed0HII9HvTGCEjTKOUWq61bnO6xy66ha6
1Pqa1zvX+/h1gV0rFXuzzVgZ2i50nOz5JVkEW41eNNzscISqW1QaX3ggjf4dOo2HlR7BAuiArs4tO6EqpGkoZX9lKqXbe58y62OetLJrGNuWGS27gk42fsCFrg9nhCFHxrHbo/hSk3AI/vwDLU82OSJzBuQxb/AT4DbhEKZWulLpDKTVCKTXCu8ogYK1SajXwBnCj9rMB3P9q+S+qBFfhmaXP4NEyrZcIQEpB39egfneYdR9skfNKldE59aGXB1/oQy9t5raZPLroUZ7q+BTXNbzO7HCEMEdRLkzqCTkH4J9LjLlMRYUq1z70QNG3bl9aVWvFy8tf5kDeAbPDEcIcweFw7ftQdAxmjZaRL5WMJPRzpJRiTCejEuMjix7BLRdfiEBVvQl0fQw2zoI1n5kdjShFEvp5qBNZh4fbPczvB37nw3Ufmh2OEObpOBJqd4TvHjQmoxaVgiT08zSg/gB6JvVk/MrxMhmGCFwWKwx4GzwueL8r/PwS5EmZDLNJQj9PSike7/A4saGx/GfhfyhwSR1pEaCi68Kt/zMKe81/Bl5pAvOekX51E0lCvwBRwVE82+lZdufs5t3V75odjhDmqdUObvkS7lkGjfvCwpdgzmOS1E0iCf0CtYtvR/96/Zm8bjKbj2w2OxwhzBV3CVw3EdoNh9/egp+ekqRuAknoF+GBNg8QERTB0789LRccCaEUXP0itL4Nfn0Nvr1PTphWMEnoF6GKowr/bvtv1mSuYfqm6WaHI4T5lIJrXoF2d0Hah/Bac/h0MOxeYnZkAUES+kXqU7cP7ePb89qK19ievd3scIQwn8UCvV+Ee1dDp3th92/w4dWwSiZfL2+S0C+SUoqn//Y0QdYgRs4dyZHCI2aHJETlULWOUdTr3jWQfIUxG9KSd8yOyq9JQi8DCeEJvN7ldTLyMhg9fzTF7mKzQxKi8ggOh5s/h0Z94If/GEMb3S6zo/JLktDLSEq1FJ697FlWHFzBU4ufws8KTgpxcWzBcP1kowTvwpeMWZB2LTY7Kr8jCb0M9UruxciUkXyz/RueX/a8JHUhSrPaoP9b8PePjeJeH14NM0ZAgXRTlhWb2QH4m+EthnOs+BhT1k/Boiw82PZBlEzZJYRBKeMCpHrd4JeXjeGNO36Bge9A8uVmR+fzpIVexpRSPNDmAW5pfAsfb/iYcWnjpKUuxJ8FhUK3x+GOH43umMl9jYuR3E6zI/Np0kIvB0opHmz7IG7tZsr6KVR1VOUfzf9hdlhCVD4JrWDEL/DDQ7DoVUhPg+tTIcwvpiWucNJCLydKKR5q9xBXJ1/N6yte57vt35kdkhCVU1AY9HsTBrwDe5bBe51h/2qzo/JJktDLkUVZeKbTM7Su3prHfn2M5RnLzQ5JiMor5Sa4/QfQHpjYA765Fw7KxOznQxJ6OQuyBvF6l9dJCE9g1LxRcjWpEH8loRUM/xku/Tus/hTe7gCT+8Gaz435TMVfkkmiK0h6TjqDvxuMw+rgo94fUS20mtkhCVG55WXBilSjJkz2HrCFQKNroOezEFHD7OhMI5NEVwKJEYm83f1tjhQd4Z8//ZPcYmltCPGXwmLg8vuN0gG3/QApN8PGb+GzW8AlV2OfjiT0CtQ0pimvdn6VbUe3MXqBlAgQ4pxYLFCnI/R5xRivnv47zH7Y7KgqJUnoFaxTQiee7vQ0S/cvZcLqCWaHI4RvaToA/jYKfp8Iqz4xO5pKRxK6CfrV68eA+gP4cO2HbDq8yexwhPAt3Z6EpMth1mjYt8rsaCoVSegmeaDNA0QFR/Hk4idxeaTynBDnzGrzXnwUB5/cBMf2mx1RpSEJ3SRRwVE83P5h1mWtY+qGqWaHI4RvCYuFmz41inx9ciMU55sdUaUgCd1EPev0pHNiZ95a+RaL9y6WeUmFOB81msGgSXBgDcwYbiT1ohwoOBqwE1TLOHSTHcg7wI2zbiSrMIvqodXpndyboU2HEhMSY3ZoQviG38bD7EdOXlajhdHXXr+
bUeHRj/zVOHRJ6JVAgauAn/f8zKzts1i0dxEp1VKY1HMSFiUHUEKcldaw9ks4uhssNvA4YXmqcb9OJ2hzO8SnQHRdYwikj5OE7kP+t/V/PP7r4zzY9kGGNBlidjhC+CZXMayYDD+/CHkHjWVB4dD8erj6BaNkr4/6q4Qu5XMrmf71+vPTrp94fcXrXJZwGclRyWaHJITvsQVBuzuh9TDI3GhUb9z1Gyz/EDI3GbMmhflft6bvH3/4GaUUT3Z8EofNwWOLHpMhjUJcDKsdajSHlrfAgPFw3QewdzlM7AYH/jA7ujInCb0SiguN49H2j7Lm0Bre/+N9s8MRwn80HwTDvoXiXHjnMhjfAeb+FzI3mx1ZmZCEXkn1SupF37p9mbBqAgv2LDA7HCH8R622cPdi6PW8MZ590SvwTiejXK+Pk4ReSSmleKLjEzSOacxDvzzE9qNSR12IMhNeDTrcDcNmwX0boVZ7mHEX/PgkeNxmR3fBJKFXYg6bg9e7vI7D6mDU/FFkF2WbHZIQ/ieiOgyZAa1vg19fg9Q+sOx9OLTF5y5QkoReydUIq8GrXV5lb+5eRs0bJXXUhSgPVjv0eRWueQWOpcN3D8BbbYzbrsVmR3fOJKH7gJbVWvL85c+zJnMNd865U1rqQpQHpaDtHTD6Dxi1ykjwHhd82BvmPA6uIrMjPKuzJnSl1CSl1EGl1NozPK6UUm8opbYqpdYopVqVfZiiZ1JPXu3yKpuPbGbYD8M4VHDI7JCE8F/RycYVpiN+hdZDYfEbMKET/PySMdxRa8jJgK1zYeVUKDxmdsTAOVwpqpS6AsgFpmitm53m8d7Av4DeQHvgda11+7O9sFwpemGW7F/CqHmjiA+LZ3KvyVRxVDE7JCH83+bZ8PMLxhh2AHsoOEtVeIyIh55joenAcq8dc9GX/iulkoBZZ0jo7wILtNafeO9vAjprrf+ySLEk9Av3+4HfGfHjCBpFN+L9q94n1B5qdkhCBIacDNgyx7jyNKYeVG8KymIUB9u/Gup2gX5vQJXa5RZCeU8SnQDsKXU/3bvsdIEMV0qlKaXSMjMzy+ClA1PbGm158YoXWZu1lvsW3IfT7TQ7JCECQ0R1aDUErhlnDHtMvgKSLoM750PvcZCeZlywtH6mKeFV6ElRrfV7Wus2Wus2cXFxFfnSfqdbnW480eEJft33K3fPvZul+5diVqE1IQKexWrUjhnxC0TXg8+HwLf3g7OgYsMog+fYC9QqdT/Ru0yUs+saXsej7R9l4+GN/GPOP+j/dX++2faN2WEJEbiik+H22dBxpDGR9dsdYNu8Cnv5sqi2OBMYqZT6FOOkaPbZ+s9F2bmx0Y0MqD+A2TtnM3XDVB5Z9Ai5zlxuanST2aEJEZhsQdDzWWjYE74ZDR8NNMr21u18Yp1qTSCh7AcEnssol0+AzkAskAE8CdgBtNbvKKUU8BbQC8gHbtNan/Vsp5wULXtOj5P7F9zP/D3zeabTM/Sv39/skIQIbM5Co1bML68YE28c12k09Hj6gp5SJrgIIEXuIkbOHcmyA8t4quNTXJV0FWH2MLPDEiKwFRwx5js9LjgCQqpe0FNJQg8w+c58Rvw0gpUHV6JQ1ImsQ/v49vy77b8JtvruTC1CCJmxKOCE2kN5/6r3Wbp/Keuz1rMuax2fbfqMrIIsxl05DqvFanaIQohyIAndTwVbg7ki8QquSLwCgCnrpvBS2ks8u/RZHu/wOMrPZkIXQkhCDxi3Nr2VrMIsJq2dRExIDPek3GN2SEKIMiYJPYCMbjWaw4WHeWf1OySEJzCg/gCzQxJClCEpnxtAjs+C1D6+PWN+G8PKgyvNDkkIUYYkoQcYu8XOy1e+TM3wmoyeP5q9uXJRrxD+QhJ6AIoKjuLNrm/i9DgZOXckP+76UWZCEsIPSEIPUMlRybzS+RUOFRzivgX3cflnlzPixxEcyDtgdmhCiAskCT2AdYjvwPwb5pPaK5UhTYaw8uBK7p1/L4WuQrN
DE0JcAEnoAc5msdG6emvua30fz1/+POuz1vPfJf+VUrxC+CBJ6KJEl9pd+GfKP5m5bSZTN0w1OxwhxHmScejiJHe1uIuNWRsZlzaOtIw0WlVrRZsabWgc3ViuLhWikpOELk5iURbGXj6Wl9Ne5rd9vzF391wA+tXrx1N/ewq7xW5yhEKIM5GELk4RZg/jiY5PAJCRl8Fnmz7j/T/e53DhYV6+8mWZlFqISkr60MVfqh5WnVGtRvFkxydZvG8xd865k61HtspJUyEqIWmhi3MyqOEgqgZX5cGFDzJw5kBqhtXksoTLuLHRjTSo2sDs8IQQyAQX4jxl5GWwcO9Cfkn/hSX7l+D0OLkn5R5ua3qb1FkXogL4zIxFTqeT9PR0CgvlwhZfYA2y8mH6h8zaOYtL4y7lucufo1ZELbPDEsKv+UxC37FjBxEREcTExMgQuUpOa01WVhY5OTms0+sYu3QsIdYQUq9OlaQuRDn6q4ReqU6KFhYWSjL3EUopYmJiKCwspE/dPqT2SqXQXcidc+6UejBCmKRSJXRAkrkPKb2vGlZtyLs93uVo0VFJ6kKYREa5iDLTLLYZ47uNZ8SPI+jxRQ+qBlclKSqJDvEdGNJkCBFBEWaHKIRfq3QtdF+QmprKyJEjzQ6jUmpdvTXTrpnG/a3vp2vtrgBMWD2B3l/1Zsq6KRS7i02OUAj/JS10UeYaVG1w0tj0DVkbeG3Fa7yU9hKfbvqUsZeNJaVainkBCuGnKm1Cf/qbdazfd6xMn7NJzUie7Nv0L9fZuXMnvXr1okOHDixevJi2bdty22238eSTT3Lw4EGmTj25CuGwYcNwOBykpaVx7NgxXnnlFfr06VOmcfu6xjGNebfHuyzet5gxv41h6A9Dub3Z7fzz0n9it0ptGCHKinS5nMbWrVu5//772bhxIxs3bmTatGksWrSIcePGMXbs2FPW37lzJ8uWLePbb79lxIgRMo7+DP5W82980fcL+tfrz8Q/JnL9N9fz/Y7vcXvcZocmhF+otC30s7Wky1NycjLNmzcHoGnTpnTr1g2lFM2bN2fnzp2nrH/DDTdgsVho0KABdevWZePGjaSkpFRs0D4iPCicMZ3G0LV2V15d/ioPLnyQtyLeYniL4fSt1xeLkjaGEBdK/ntOIzg4uOR3i8VSct9iseByuU5Z/89DLWXo5dl1rtWZGf1n8GrnVwmzh/HYr49x++zb2ZG9w+zQhPBZlbaF7kumT5/O0KFD2bFjB9u3b+eSSy4xOySfYFEWutfpTrfa3fjf1v/xUtpLDJo5iJsa3URMSAwAUcFR9K3XV+qwC3EOJKGXgdq1a9OuXTuOHTvGO++8g8PhMDskn6KUYmCDgVyeeDljl45l8vrJJz3+/Y7vebnzy0QGRZoUoRC+oVLVctmwYQONGzc2JZ4LNWzYMPr06cOgQYPMDsUU5bHPClwFJfXWZ++czZglY6gdUZvx3caTGJFYpq8lhK/xmVouQgCE2EIItYcSag9lYIOBvNfjPQ4VHOLmb29m8rrJ5Bbnmh2iEJWSJPSLlJqaGrCt84rStkZbPu79MXWr1GVc2ji6f9Gdl35/iX25+05ab3/uft5e9TabDm8yKVIhzCV96MInJEclk9orlXWH1jFl/RSmbpjK1A1T6V6nO/3q9WPu7rnM3DoTl3aRui6VF694kc61OpsdthAVSlrowqc0jW3KC1e8wA/X/cCtTW5l8d7F3DP3HmZtm8WghoOY1nsayVHJjJo3isnrJsvcpyKgyElRcVHM3md5zjwW71tMSlwKcaFxgHFS9ZFfHuGn3T9RN6ouraq3olW1VnSu1VkqPgqf91cnRaXLRfi0MHsYPer0OGlZiC2Elzu/zGebPmNh+kJ+2PEDX2z+grpRdZl41cSSxC+Ev5EuF+GXLMrCTY1uYkL3CSy6cRHju41nf95+hv0w7LSTb+QW5/LS7y/x/LLncXqcJkQsxMU7p4SulOqllNqklNqqlHroNI8PU0p
lKqVWeW//KPtQK4/yroe+YMECqdhYhqwWK1ckXsF7Pd7jcOFhhv0wjBUZK8guygZg7q659P+6Px+t/4ipG6Zy/4L7pW678Eln7XJRSlmB8UAPIB34XSk1U2u9/k+rfqa1Lrss9/1DcOCPMns6AGo0h6ufL9vnvEAulwubTXq8KlJKtRTev+p9hv84nKE/DAUg3B5OrjOXhlUb8mrnV1mXZUx4PXLuSF7r8hqh9lCToxbi3J1LRmkHbNVabwdQSn0K9Af+nND9QnnWQ09NTeWrr74iNzcXt9vNd999x7/+9S/Wrl2L0+nkqaeeon///if9zVNPPUV4eDgPPPAAAM2aNWPWrFkkJSWVy/b7u2axzZg5YCZrMtewJ2cPe3L2kByVzA2X3IDdYqdFXAtCbCE8ufhJBn49kIZVG5IQkUCj6EZck3yN1G8Xldq5JPQEYE+p++lA+9Osd51S6gpgM/B/Wus9f15BKTUcGA5G/ZO/ZGJLeuvWrUyfPp1JkybRtm3bknroM2fOZOzYsQwYMOCk9Y/XQ9+2bRtdunRh69atZ6znsmLFCtasWUN0dDSPPPIIXbt2ZdKkSRw9epR27drRvXv3CtjCwBYbElsyPd7pDKg/gKrBVflyy5fszd3LsgPLyHfl887qdxhx6Qj61O2DzSJHV6LyKauTot8ASVrrFsCPwOTTraS1fk9r3UZr3SYurvKONDheD91isVxwPfQz6dGjB9HR0QDMmTOH559/npSUFDp37kxhYSG7d+8ur80S5+HKWlfyRtc3+LLflyy5eQkTuk8gMiiSx399nGtnXsucnXPwaE/J+rnFuazOXC0nVIWpzqWZsReoVep+ondZCa11Vqm7E4EXLz4085RnPfSwsLCS37XWfPnll6eU283IyCj53Waz4fGcSBwyG1LFU0pxWcJldKrZiZ92/8RbK9/i/p/vp3F0YwbUH8CyA8v4Jf0Xij3FVAupxqBLBjGowSAZHikq3Lm00H8HGiilkpVSQcCNwMzSKyil4kvd7QdsKLsQK7/p06fj8XjYtm3bedVD79mzJ2+++WbJ1YwrV648ZZ2kpCRWrFgBGN01O3bIBBBmUUrRo04Pvur3Fc90eobsomyeW/YcazLXcP0l1zP2srE0qNqAt1e9zVVfXMWIn0bw9davySnOMTt0ESDO2kLXWruUUiOB2YAVmKS1XqeUGgOkaa1nAqOUUv0AF3AYGFaOMVc6F1oP/fHHH2f06NG0aNECj8dDcnIys2bNOmmd6667jilTptC0aVPat29Pw4YNy2MTxHmwWqz0r9+fq5OvZtexXdSNqovVYgWgb72+7MzeyVdbv2L2jtk89utjWBdbCQ8KJ8QWQrg9nGsbXMvNjW4u+Rshyopc+n+RpB667+2ziqK15o9Df7AwfSHZRdkUuArYk7OHFQdX0DSmKU/97SmigqJYsn8JqzJX0TG+I72Se5kdtqjk5NJ/IUyglKJFXAtaxLUoWaa1ZvbO2Ty37Dmu/+b6kuUOq4OvtnxFWkYaD7Z9kCBrkBkhCx8nCf0ipaamnrJs9uzZ/Oc//zlpWXJyMjNmzKigqERlpZSiV3IvOtbsyLQN0wgPCqd9fHuSo5J5c8WbfLjuQ9YdWsfVyVezPXs7W49uxWF10CKuBZfGXUrN8Jo4PU6cHidWZSUqKIooRxQR9giZnFxIl4u4OLLPytbcXXN57NfHyHXmUjW4KvWq1CPPmcfmI5txa/cZ/y4qOIqutbpyVdJVtK/RXi6A8mPS5SKEj+hWpxvt49tT5C4iJiSmZHmBq4B1h9aRVZhFkCUIu9WOy+Miuyibo0VH2XB4A3N2zWHG1hnEOGK4t9W99K/fH4uS+nuBRBK6EJVMeFA44YSftCzEFkKbGqdtlJUochexeO9iJq2dxBOLn+DTTZ9y96V3E2QNIt+Zz7HiY6TnpLM3dy9HCo/QIq4FlyVcRrPYZnLlq5+QLhdxUWSfVT5aa77b8R2vLH+Fg/kHT3rMqqz
UCKtBRFAEm49sxqM9RAZFclOjm7i16a1EBkWe12tlF2Vjt9iliFkFki4XIQKIUopr6l5Dl1pdWJ25mmBrMGH2MMKDwqkWWg27xehfzy7KZsn+JXy/43veXfMu0zZMY0jTIXSI70B8WDxxIXEopSh0FZLvyifMHkaILQSA7Ue38+G6D5m1fRYOq4PBjQczpMkQooKjzNz0gCct9AuQmppKWloab731Vrk8/4IFCxg3btwpFxmdi/DwcHJzc8shqtPzlX0m/trGwxsZv2o8C/YsKFlmVVY82oPmRI6ICo4ixhHD9uztOKwOBtQfQFZhFj/u+pEwexjta7THoiwopYgPi6dP3T40im4kI3DKkE+20F9Y9gIbD5+5yNWFaBTdiP+0+8/ZV6wAvlgPXWuN1hqLRU60+ZtG0Y14s+ubpOeks/PYTvbl7uNA3gGsFiuhtlBCbCHkFOeQkZ9BRn4GPZN6cmOjG4l2GIXmNh/ZzMQ1E9lydAtgfFYW7FnAlPVTqF+lPp1rdaZ6aHXiQuKoEVaDpKgkwuxhaK1Zn7We2Ttns+vYLoa3GE7T2KYmvhO+zbcySgXwl3roubm59O/fnyNHjuB0OnnmmWfo378/TzzxBNHR0YwePRqARx99lGrVqnHvvffy0ksv8fnnn1NUVMTAgQN5+umn2blzJz179qR9+/YsX76c7777jjp16lzYmysqvcSIRBIjEs/77xpWbciLV55cky+7KJvZO2czc9tMJq2ddFJ1SoBqodWwKiv78/ZjUzbCgsIY/N1ghjUdxt0pdxNsDUacn0qb0M1sSftDPXSHw8GMGTOIjIzk0KFDdOjQgX79+nH77bdz7bXXMnr0aDweD59++inLli1jzpw5bNmyhWXLlqG1pl+/fixcuJDatWuzZcsWJk+eTIcOHcokNhEYooKjuOGSG7jhkhtwe9wcLjxMZkEm+3P3s+PYDnZk7yDPmcfdl95N19pdUUox7vdxfLD2A+btmcfdl95Njzo9SkbgFLuL2Xh4I1UdVUkITzjtkMwidxHL9i8j1B5Ks9hmZ/xSOFRwCDBq4/uTSpvQzXS8HjpwwfXQU1JSTvvcf66HPnPmTMaNGwdQpvXQtdY88sgjLFy4EIvFwt69e8nIyCApKYmYmBhWrlxJRkYGLVu2JCYmhjlz5jBnzhxatmwJGC38LVu2ULt2berUqSPJXFwUq8VKXGgccaFxNIlpcsb1xnQaQ8+knjy/7HkeXPggCeEJ9K/Xny1Ht/Dr3l/Jd+UDEGwNJikyiTqRdagTWYf48HhWZqxk/p755DqNc0hBliCaxTajbpW6xDhiiAmJYW/OXhbvX8yWI1uwKAsd4zvSv35/utTqgsN2bkX1KjNJ6KfhD/XQp06dSmZmJsuXL8dut5OUlFTyt//4xz9ITU3lwIED3H777SWxPPzww9x1110nPc/OnTtPilmI8tYpoRNfD/iaBXsW8OHaD3l79dtUC6lG77q96RjfkZziHLZnb2d79nY2Ht7I3N1zcWs3EUERdK/TnavqXIXT42RFxgpWHlzJvN3zOFJ4BI3GbrHTqnor/q/1/5HnzOObbd/w4MIHcVgdtK7Rmo7xHbky8UqSopLOKdZ8Zz5pGWkkRSZRK6KW6Sd/JaGXgenTpzN06FB27NhxQfXQ33zzTZRSrFy5sqSFfFxSUlLJaJfzqYeenZ1NtWrVsNvtzJ8/n127dpU8NnDgQJ544gmcTifTpk0rieXxxx9n8ODBhIeHs3fvXux2uXxcmMOiLHSt3ZWutbtyqOAQMY6YMyZLp8fJgbwD1AitcVLJg9LTDLo8Lo4WHT1p6CXAPSn3sOzAMubvns9v+39jXNo4xqWN47KEy7it6W20rdH2pKGbEUER2C12jhUf49ONn/Lx+o85UnQEgCrBVWge25xG0Y1oGN2Q5Mhk0nPS+ePQH2w8spFqIdVoHtec5rHNqV+lfrlczCUJvQxUxnrogwcPpm/fvjRv3pw2bdrQqFGjkse
CgoLo0qULVapUwWo1anJfddVVbNiwgY4dOwLG8MePP/645HEhzHK2fm67xU6tiFp/uY7NYjvt81iUhQ7xHegQb3QpHsg7wNdbv2baxmncMecO4kLiyHflk+fMK/mbMHsYbo+bQnchlydczo2NbuRg/kHWZK5hTeYaFu9bfFLdHZuyUbdKXdYdWseMrUaBvpsb3czD7R8+5/fgXMk49Ivki/XQPR4PrVq1Yvr06TRo0OCinssX95kQZ1PkLuKbbd+wPGM5VYKrEBMSUzJ0M7soG5fHxcAGA097PqDIXcT2o9vZeWwnCeEJXBJ9CcHWYLTWpOem80fmH9SOrE2z2GYXFJtPjkMX5WP9+vX06dOHgQMHXnQyF8JfBVuDGdRwEIMann9DLdgaTOOYxjSOObmho5SiVkStsx5NXAxJ6BfJjHroWVlZdOvW7ZTlc+fOJSYm5jR/cUKTJk3Yvn17mcQhhKhcKl1C11qbfqb4YvXs2ZOePXuW2/PHxMSwatWqcnv+c2VWd50Q4vQq1TXcDoeDrKwsSRQ+QGtNVlbWOZ8AFkKUv0rVQk9MTCQ9PZ3MzEyzQxHnwOFwkJh4/peJCyHKR6VK6Ha7neTkZLPDEEIIn1SpulyEEEJcOEnoQgjhJyShCyGEnzDtSlGlVCaw66wrnl4scKgMw/EVgbjdgbjNEJjbHYjbDOe/3XW01nGne8C0hH4xlFJpZ7r01Z8F4nYH4jZDYG53IG4zlO12S5eLEEL4CUnoQgjhJ3w1ob9ndgAmCcTtDsRthsDc7kDcZijD7fbJPnQhhBCn8tUWuhBCiD+RhC6EEH7C5xK6UqqXUmqTUmqrUuohs+MpD0qpWkqp+Uqp9UqpdUqpe73Lo5VSPyqltnh/VjU71vKglLIqpVYqpWZ57ycrpZZ69/lnSqkgs2MsS0qpKkqpL5RSG5VSG5RSHQNhXyul/s/7+V6rlPpEKeXwx32tlJqklDqolFpbatlp968yvOHd/jVKqVbn81o+ldCVUlZgPHA10AS4SSl16hxQvs8F3K+1bgJ0AO7xbudDwFytdQNgrve+P7oX2FDq/gvAq1rr+sAR4A5Toio/rwM/aK0bAZdibLtf72ulVAIwCmijtW4GWIEb8c99nQr0+tOyM+3fq4EG3ttwYML5vJBPJXSgHbBVa71da10MfAr0NzmmMqe13q+1XuH9PQfjHzwBY1sne1ebDAwwJcBypJRKBK4BJnrvK6Ar8IV3Fb/abqVUFHAF8AGA1rpYa32UANjXGNVeQ5RSNiAU2I8f7mut9ULg8J8Wn2n/9gemaMMSoIpSKv5cX8vXEnoCsKfU/XTvMr+llEoCWgJLgepa6/3ehw4A1c2Kqxy9BjwIeLz3Y4CjWmuX976/7fNkIBP40NvNNFEpFYaf72ut9V5gHLAbI5FnA8vx731d2pn270XlOF9L6AFFKRUOfAmM1lofK/2YNsab+tWYU6VUH+Cg1nq52bFUIBvQCpigtW4J5PGn7hU/3ddVMVqjyUBNIIxTuyUCQlnuX19L6HuB0lNmJ3qX+R2llB0jmU/VWn/lXZxx/PDL+/OgWfGVk05AP6XUTozutK4Y/ctVvIfl4H/7PB1I11ov9d7/AiPB+/u+7g7s0Fpnaq2dwFcY+9+f93VpZ9q/F5XjfC2h/w408J4JD8I4iTLT5JjKnLff+ANgg9b6lVIPzQSGen8fCnxd0bGVJ631w1rrRK11Esa+nae1HgzMBwZ5V/Or7dZaHwD2KKUu8S7qBqzHz/c1RldLB6VUqPfzfny7/XZf/8mZ9u9M4FbvaJcOQHaprpmz01r71A3oDWwGtgGPmh1POW3jZRiHYGuAVd5bb4z+5LnAFuAnINrsWMvxPegMzPL+XhdYBmwFpgPBZsdXxtuaAqR59/f/gKqBsK+Bp4GNwFrgIyDYH/c18AnGeQInxhHZHWfav4DCGMm3DfgDYxTQOb+WXPovhBB+wte6XIQQQpyBJHQhhPATktCFEMJPSEI
XQgg/IQldCCH8hCR0IYTwE5LQhRDCT/w/r5Zk9OLvOv8AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for i in ['mlp', 'mlp_relu', 'mlp_relu_layer']:\n", " plt.plot(torch.tensor(loss[i]).view(-1, 10).mean(dim=-1), label=i)\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 0.200 val 0.197 test 0.217\n", "epoch 1 train 0.129 val 0.157 test 0.147\n", "epoch 2 train 0.100 val 0.133 test 0.140\n", "epoch 3 train 0.088 val 0.142 test 0.125\n", "epoch 4 train 0.088 val 0.142 test 0.139\n", "epoch 5 train 0.075 val 0.137 test 0.132\n", "epoch 6 train 0.063 val 0.139 test 0.120\n", "epoch 7 train 0.059 val 0.118 test 0.119\n", "epoch 8 train 0.056 val 0.133 test 0.119\n", "epoch 9 train 0.055 val 0.118 test 0.125\n", "epoch 10 train 0.046 val 0.112 test 0.135\n", "epoch 11 train 0.042 val 0.128 test 0.122\n", "epoch 12 train 0.039 val 0.127 test 0.147\n", "epoch 13 train 0.044 val 0.115 test 0.126\n", "epoch 14 train 0.031 val 0.116 test 0.120\n", "epoch 15 train 0.031 val 0.134 test 0.133\n", "epoch 16 train 0.027 val 0.117 test 0.137\n", "epoch 17 train 0.024 val 0.135 test 0.117\n", "epoch 18 train 0.027 val 0.141 test 0.145\n", "epoch 19 train 0.027 val 0.133 test 0.178\n" ] } ], "source": [ "# Model overfitting: the train value keeps dropping while val/test stop improving\n", "model = nn.Sequential(\n", " nn.Linear(784, 30, bias=False), nn.LayerNorm(30), nn.ReLU(),\n", " nn.Linear( 30, 20, bias=False), nn.LayerNorm(20), nn.ReLU(),\n", " nn.Linear( 20, 10)\n", ")\n", "\n", "loss['mlp_relu_layer'] = train_model(model, optim.Adam(model.parameters(), lr=0.01), epochs=20)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "m = nn.Dropout(0.5)\n", "x = torch.randn(5, requires_grad=True)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 1.3846, 0.5061, 0.3079, 1.1239, -1.9939], requires_grad=True)" ] }, 
"execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m.train()\n", "x.grad = None  # backward() accumulates into x.grad; clear it so the result reflects this pass only\n", "out = m(x)\n", "out.sum().backward()\n", "out, x.grad" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 1.3846, 0.5061, 0.3079, 1.1239, -1.9939], requires_grad=True)" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.eval()\n", "m(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m.training" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 0.265 val 0.260 test 0.261\n", "epoch 1 train 0.182 val 0.200 test 0.204\n", "epoch 2 train 0.153 val 0.187 test 0.196\n", "epoch 3 train 0.161 val 0.184 test 0.173\n", "epoch 4 train 0.179 val 0.178 test 0.170\n", "epoch 5 train 0.134 val 0.167 test 0.162\n", "epoch 6 train 0.122 val 0.171 test 0.164\n", "epoch 7 train 0.124 val 0.157 test 0.157\n", "epoch 8 train 0.116 val 0.160 test 0.149\n", "epoch 9 train 0.128 val 0.154 test 0.158\n", "epoch 10 train 0.104 val 0.153 test 0.147\n", "epoch 11 train 0.121 val 0.153 test 0.149\n", "epoch 12 train 0.117 val 0.142 test 0.158\n", "epoch 13 train 0.099 val 0.143 test 0.151\n", "epoch 14 train 0.110 val 0.159 test 0.139\n", "epoch 15 train 0.110 val 0.159 test 0.138\n", "epoch 16 train 0.105 val 0.138 test 0.151\n", "epoch 17 train 0.096 val 0.131 test 0.142\n", "epoch 18 train 0.089 val 0.153 test 0.156\n", "epoch 
19 train 0.106 val 0.155 test 0.164\n" ] } ], "source": [ "# Same MLP with Dropout(0.2) after each hidden ReLU, to counter the overfitting above\n", "model = nn.Sequential(\n", " nn.Linear(784, 30, bias=False), nn.LayerNorm(30), nn.ReLU(), nn.Dropout(0.2),\n", " nn.Linear( 30, 20, bias=False), nn.LayerNorm(20), nn.ReLU(), nn.Dropout(0.2),\n", " nn.Linear( 20, 10)\n", ")\n", "\n", "loss['mlp_relu_layer_dropout'] = train_model(model, optim.Adam(model.parameters(), lr=0.01), epochs=20)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 0.286 val 0.292 test 0.287\n", "epoch 1 train 0.256 val 0.253 test 0.235\n", "epoch 2 train 0.247 val 0.261 test 0.242\n", "epoch 3 train 0.218 val 0.241 test 0.218\n", "epoch 4 train 0.211 val 0.220 test 0.205\n", "epoch 5 train 0.227 val 0.235 test 0.227\n", "epoch 6 train 0.219 val 0.243 test 0.231\n", "epoch 7 train 0.209 val 0.233 test 0.217\n", "epoch 8 train 0.215 val 0.210 test 0.201\n", "epoch 9 train 0.205 val 0.208 test 0.204\n", "epoch 10 train 0.262 val 0.279 test 0.249\n", "epoch 11 train 0.200 val 0.221 test 0.208\n", "epoch 12 train 0.208 val 0.237 test 0.212\n", "epoch 13 train 0.219 val 0.226 test 0.206\n", "epoch 14 train 0.209 val 0.222 test 0.216\n", "epoch 15 train 0.194 val 0.236 test 0.215\n", "epoch 16 train 0.204 val 0.212 test 0.206\n", "epoch 17 train 0.213 val 0.231 test 0.208\n", "epoch 18 train 0.202 val 0.232 test 0.211\n", "epoch 19 train 0.228 val 0.247 test 0.236\n" ] } ], "source": [ "# Penalty term: same MLP without dropout, trained with penalty=True\n", "model = nn.Sequential(\n", " nn.Linear(784, 30, bias=False), nn.LayerNorm(30), nn.ReLU(), \n", " nn.Linear( 30, 20, bias=False), nn.LayerNorm(20), nn.ReLU(), \n", " nn.Linear( 20, 10)\n", ")\n", "\n", "_ = train_model(model, optim.Adam(model.parameters(), lr=0.01), epochs=20, penalty=True)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': {'loss': 0.21390601992607117, 'acc': 0.9363999366760254},\n", " 'val': {'loss': 
0.26335230469703674, 'acc': 0.9261999130249023},\n", " 'test': {'loss': 0.24580919742584229, 'acc': 0.928600013256073}}" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "estimate_loss(model)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }