{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n", "\u001b[0mRequirement already satisfied: torchvision in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (0.15.2)\n", "Requirement already satisfied: numpy in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torchvision) (1.24.4)\n", "Requirement already satisfied: requests in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torchvision) (2.24.0)\n", "Requirement already satisfied: torch==2.0.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torchvision) (2.0.1)\n", "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torchvision) (8.0.1)\n", "Requirement already satisfied: filelock in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch==2.0.1->torchvision) (3.0.12)\n", "Requirement already satisfied: typing-extensions in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch==2.0.1->torchvision) (4.8.0)\n", "Requirement already satisfied: sympy in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch==2.0.1->torchvision) (1.6.2)\n", "Requirement already satisfied: networkx in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch==2.0.1->torchvision) (2.5)\n", "Requirement already satisfied: jinja2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch==2.0.1->torchvision) (2.11.2)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->torchvision) (3.0.4)\n", "Requirement already satisfied: idna<3,>=2.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->torchvision) (2.10)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->torchvision) (1.25.11)\n", "Requirement already satisfied: certifi>=2017.4.17 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->torchvision) (2020.6.20)\n", "Requirement already satisfied: MarkupSafe>=0.23 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from jinja2->torch==2.0.1->torchvision) (1.1.1)\n", "Requirement already satisfied: decorator>=4.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from networkx->torch==2.0.1->torchvision) (4.4.2)\n", "Requirement already satisfied: mpmath>=0.19 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from sympy->torch==2.0.1->torchvision) (1.1.0)\n", "\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mDEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n", "\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" ] } ], "source": [ "!pip install torchvision" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, random_split\n", "from torchvision import datasets\n", "import torchvision.transforms as transforms\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "torch.manual_seed(12046)\n", "\n", "dataset = datasets.MNIST(root='./mnist', train=True, download=True, transform=transforms.ToTensor())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 28, 28])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, y = dataset[21]\n", "x.shape" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOCElEQVR4nO3da6hd9ZnH8d9vYosQRdQcY0jFUy8YZHBsOYSBBuOFKcY3SYNIfdEkIqZIxAs1jmaEegERnbaMMhRPR0kcO5ZiNYpIp5lQL31TcpRMEm8TRyIaojkhStQX6XjyzIuzLEc9+7+Oe+3byfP9wGHvvZ699nqyk1/WPuu/1v47IgTg6Pc3/W4AQG8QdiAJwg4kQdiBJAg7kMQxvdzYvHnzYnh4uJebBFLZs2ePDhw44OlqjcJu+1JJ/yJpjqR/i4h7S88fHh7W2NhYk00CKBgZGWlZa/tjvO05kv5V0jJJ50q60va57b4egO5q8jv7YklvRcTbEfEXSb+RtLwzbQHotCZhXyjp3SmP36uWfYHttbbHbI+Nj4832ByAJrp+ND4iRiNiJCJGhoaGur05AC00CfteSadNefytahmAAdQk7NsknW3727a/KemHkp7pTFsAOq3tobeI+Mz2dZL+U5NDb49ExKsd6wxARzUaZ4+I5yQ916FeAHQRp8sCSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kERPp2wGpnr++eeL9YsvvrhYj4i2X3/p0qXFdY9G7NmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnG2dFVGzdubFl74IEHiuvOmTOnWJ+YmCjWb7rpppa11atXF9ddt25dsX7MMbMvOo06tr1H0seSJiR9FhEjnWgKQOd14r+niyLiQAdeB0AX8Ts7kETTsIekP9h+2fba6Z5ge63tMdtj4+PjDTcHoF1Nw74kIr4raZmkdbYv+PITImI0IkYiYmRoaKjh5gC0q1HYI2Jvdbtf0lOSFneiKQCd13bYbc+1ffzn9yV9X9KuTjUGoLOaHI2fL+kp25+/zn9ExO870hVmjdI4uiQ9+uijLWs7d+7scDczf/2bb765uO6KFSuK9dNPP72dlvqq7bBHxNuS/q6DvQDoIobegCQIO5AEYQeSIOxAEoQdSGL2XaeHr+Wjjz4q1rdv316sX3XVVcV63SnQhw8fLtZLFi1aVKzXXeK6e/futrd9NGLPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMM5+FNi8eXPL2ujoaHHdLVu2FOt1Y9l1X/fcxPr164v1I0eOFOvXXHNNJ9uZ9dizA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLPPAo899lixvmrVqq5tOyKK9bpx+G5uu043e5uN2LMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMsw+AunH0G264oVgvXVN+7LHHFtc95ZRTivVPPvmkWD948GCxXlLX2/HHH1+sHzp0qFjv5rX2s1Htnt32I7b32941ZdlJtrfY3l3dntjdNgE0NZOP8RslXfqlZbdK2hoRZ0vaWj0GMMBqwx4RL0r68me15ZI2Vfc3SVrR2bYAdFq7B+jmR8S+6v77kua3eqLttbbHbI/VzQsGoHsaH42PyasVWl6xEBGjETESESNDQ0NNNwegTe2G/QPbCySput3fuZYAdEO7YX9G0urq/mpJT3emHQDdUjvObvtxSRdKmmf7PUk/lXSvpN/avlrSO5Ku6GaTs13pe92l+uvRm4wXL168uFjfunVrsb5x48Zivcl3s99zzz3F+sqVK4v1ut7wRbVhj4grW5Qu6XAvALqI02WBJAg7kARhB5Ig7EAShB1IgktcO6BuCOjGG29s9Pp1l4KWhtcefPDBRtuuc9555xXra9asaVm79tprG2378ssvL9ZL01Vv27at0bZnI/bsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+wdcNdddxXrn376aaPX37BhQ7F+2223NXr9kiVLlhTry5YtK9bnz2/5jWWNHXfcccV63fkJ2bBnB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkGGefoe3bt7es1U1rPDExUawfOXKknZZ64qyzzup3C22bnKxoenV/J0cj9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7JVdu3YV66Xpgz/88MPiuk2mXEZrdec3HD58uGUt499J7Z7d9iO299veNWXZHbb32t5e/VzW3TYBNDWTj/EbJV06zfJfRMT51c9znW0LQKfVhj0iXpR0sAe9AOiiJgforrO9o/qYf2KrJ9lea3vM9tj4+HiDzQFoot2w/1LSmZLOl7RP0s9aPTEiRiNiJCJGhoaG2twcgKbaCntEfBARExFxRNKvJLWeRhTAQGgr7LYXTHn4A0nlcSsAfVc7zm77cUkXSppn+z1JP5V0oe3zJYWkPZJ+3L0We+P6668v1t99990edYKZeuKJJ4r1jHOwl9SGPSKunGbxw13oBUAXcboskARhB5Ig7EAShB1IgrADSXCJaw/cd999/W5hVnrjjTeK9VtuuaXt1x4eHi7Wj8bpntmzA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLP3wMknn9zvFgZS3Tj68uXLi/UDBw4U6/Pnz29Zq7s8trTubMWeHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJy9EhHF+sTERNuvvWbNmmJ91apVbb92v9VNm1z6s23evLnRts8888xi/dlnn21ZO+eccxptezZizw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDOXrn99tuL9R07drSsHTp0qNG2L7roomLddrFeuu67bjy57jvt684/OHz4cLFemjZ57ty5xXU3bNhQrK9cubJYzziWXlK7Z7d9mu0/2n7N9qu2b6iWn2R7i+3d1e2J3W8XQLtm8jH+M0k/iYhzJf29pHW2z5V0q6StEXG2pK3VYwADqjbsEbEvIl6p7n8s6XVJCyUtl7SpetomSSu61COADvhaB+hsD0v6jqQ/S5ofEfuq0vuSpv3SLttrbY/ZHhsfH2/SK4AGZhx228dJ+p2kGyPiC0ekYvIozrRHciJiNCJGImJkaGioUbMA2jejsNv+hiaD/uuIeLJa/IHtBVV9gaT93WkRQCfUDr15ctznYUmvR8TPp5SekbRa0r3V7dNd6bBHLrnkkmL9ySefbFmrGwKqG5p74YUXivU5c+YU6y+99FKx3kTdpb11vV1wwQUta6tXry6uO5sv/R1EMxln/56kH0naaXt7tWyDJkP+W9tXS3pH0hVd6RBAR9SGPSL+JKnVWR3l3SGAgcHpskAShB1IgrADSRB2IAnCDiTBJa4ztHTp0pa10uWvkjQ6Olqs33333W311AunnnpqsV4aR5ekhx56qGXthBNOaKsntIc9O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwTh7ByxcuLBYv/POO4v1M844o1i///77i/U333yzZW3RokXFddevX1+s1/W2ZMmSYh2Dgz07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOPsAqPv+9Lo6MBPs2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgidqw2z7N9h9tv2b7Vds3VMvvsL3X9vbq57LutwugXTM5qeYzST+JiFdsHy/pZdtbqtovIuKfu9cegE6Zyfzs+yTtq+5/bPt1SeWvZgEwcL7W7+y2hyV9R9Kfq0XX2d5h+xHbJ7ZYZ63tMdtj4+PjzboF0LYZh932cZJ+J+nGiDgk6ZeSzpR0vib3/D+bbr2IGI2IkYgYGRoaat4xgLbMKOy2v6HJoP86Ip6UpIj4ICImIuKIpF9JWty9NgE0NZOj8Zb0sKTXI+LnU5YvmPK0H0ja1fn2AHTKTI7Gf0/SjyTttL29WrZB0pW2z5cUkvZI+nEX+gPQITM5Gv8nSZ6m9Fzn2wHQLZxBByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSMIR0buN2eOS3pmyaJ6kAz1r4OsZ1N4GtS+J3trVyd5Oj4hpv/+tp2H/ysbtsYgY6VsDBYPa26D2JdFbu3rVGx/jgSQIO5BEv8M+2uftlwxqb4Pal0Rv7epJb339nR1A7/R7zw6gRwg7kERfwm77Uttv2n7L9q396KEV23ts76ymoR7rcy+P2N5ve9eUZSfZ3mJ7d3U77Rx7feptIKbxLkwz3tf3rt/Tn/f8d3bbcyT9j6R/kPSepG2SroyI13raSAu290gaiYi+n4Bh+wJJn0h6NCL+tlp2n6SDEXFv9R/liRHxjwPS2x2SPun3NN7VbEULpk4zLmmFpDXq43tX6OsK9eB968eefbGktyLi7Yj4i6TfSFrehz4GXkS8KOnglxYvl7Spur9Jk/9Yeq5FbwMhIvZFxCvV/Y8lfT7NeF/fu0JfPdGPsC+U9O6Ux+9psOZ7D0l/sP2y7bX9bmYa8yNiX3X/fUnz+9nMNGqn8e6lL00zPjDvXTvTnzfFAbqvWhIR35W0TNK66uPqQIrJ38EGaex0RtN498o004z/VT/fu3anP2+qH2HfK+m0KY+/VS0bCBGxt7rdL+kpDd5U1B98PoNudbu/z/381SBN4z3dNOMagPeun9Of9yPs2ySdbfvbtr8p6YeSnulDH19he2514ES250r6vgZvKupnJK2u7q+W9HQfe/mCQZnGu9U04+rze9f36c8jouc/ki7T5BH5/5X0T/3ooUVfZ0j67+rn1X73JulxTX6s+z9NHtu4WtLJkrZK2i3pvySdNEC9/buknZJ2aDJYC/rU2xJNfkTfIWl79XNZv9+7Ql89ed84XRZIggN0QBKEHUiCsANJEHYgCcIOJEHYgSQIO5DE/wOeBksu2CwdCAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.imshow(x.squeeze(0).numpy(), cmap=plt.cm.binary)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "train_set, val_set = random_split(dataset, [50000, 10000])\n", "test_set = datasets.MNIST(root='./mnist', train=False, download=True, transform=transforms.ToTensor())" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "train_loader = DataLoader(train_set, batch_size=500, shuffle=True)\n", "val_loader = DataLoader(val_set, batch_size=500, shuffle=True)\n", "test_loader = DataLoader(test_set, batch_size=500, shuffle=True)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([500, 1, 28, 28]), torch.Size([500]), torch.Size([500, 784]))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, y = next(iter(train_loader))\n", "x.shape, y.shape, x.view(x.shape[0], -1).shape" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", " \n", " def __init__(self):\n", " pass\n", " \n", " def forward(self, x):\n", " pass\n", " \n", "model = MLP()\n", "\n", "model = nn.Sequential(\n", " nn.Linear(784, 30), nn.Sigmoid(),\n", " nn.Linear( 30, 20), nn.Sigmoid(),\n", " nn.Linear( 20, 10)\n", ")" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "eval_iters = 10\n", "\n", "\n", "def estimate_loss(model):\n", " re = {}\n", " # 将模型切换为评估模式\n", " model.eval()\n", " re['train'] = _loss(model, train_loader)\n", " re['val'] = _loss(model, val_loader)\n", " re['test'] = _loss(model, test_loader)\n", " # 将模型切换为训练模式\n", " model.train()\n", " return re\n", "\n", " \n", "@torch.no_grad()\n", "def _loss(model, dataloader):\n", " # 估算模型效果\n", " loss = []\n", " acc = []\n", " data_iter = iter(dataloader)\n", " for t in range(eval_iters):\n", " inputs, labels = next(data_iter)\n", " # inputs: (500, 1, 28, 28)\n", " # labels: (500)\n", " B, C, H, W = inputs.shape\n", " logits = model(inputs.view(B, -1))\n", " loss.append(F.cross_entropy(logits, labels))\n", " # preds = torch.argmax(F.softmax(logits, dim=-1), dim=-1)\n", " preds = torch.argmax(logits, dim=-1)\n", " acc.append((preds == labels).sum() / B)\n", " re = {\n", " 'loss': torch.tensor(loss).mean().item(),\n", " 'acc': torch.tensor(acc).mean().item()\n", " }\n", " return re" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([3, 3, 3, 2, 3, 1]), tensor([1, 1, 0, 2, 3, 3]))" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = torch.randint(4, (6, ))\n", "labels = torch.randint(4, (6, ))\n", "preds, labels" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(2)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(preds == labels).sum()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': {'loss': 2.346519947052002, 'acc': 0.1096000075340271},\n", " 'val': {'loss': 2.345228672027588, 'acc': 0.10600000619888306},\n", " 'test': {'loss': 2.342650890350342, 'acc': 0.11500000953674316}}" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "estimate_loss(model)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "def train_model(model, optimizer, epochs=10, penalty=False):\n", " lossi = []\n", " for e in range(epochs):\n", " for data in train_loader:\n", " inputs, labels = data\n", " B, C, H, W = inputs.shape\n", " logits = model(inputs.view(B, -1))\n", " loss = F.cross_entropy(logits, labels)\n", " lossi.append(loss.item())\n", " if penalty:\n", " w = torch.cat([p.view(-1) for p in model.parameters()])\n", " loss += 0.001 * w.abs().sum() + 0.002 * w.square().sum()\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " stats = estimate_loss(model)\n", " train_loss = f'{stats[\"train\"][\"loss\"]:.3f}'\n", " val_loss = f'{stats[\"val\"][\"loss\"]:.3f}'\n", " test_loss = f'{stats[\"test\"][\"loss\"]:.3f}'\n", " print(f'epoch {e} train {train_loss} val {val_loss} test {test_loss}')\n", " return lossi" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "loss = {}" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 2.317 val 2.316 test 2.313\n", "epoch 1 train 2.306 val 2.304 test 2.304\n", "epoch 2 train 2.301 val 2.303 test 2.302\n", "epoch 3 train 2.300 val 2.300 test 2.301\n", "epoch 4 train 2.299 val 2.301 test 2.301\n", "epoch 5 train 2.299 val 2.300 test 2.299\n", "epoch 6 train 2.298 val 2.300 test 2.298\n", "epoch 7 train 2.298 val 2.299 test 2.298\n", "epoch 8 train 2.298 val 2.299 test 2.299\n", "epoch 9 train 2.296 val 2.297 test 2.297\n" ] } ], "source": [ "model = nn.Sequential(\n", " nn.Linear(784, 30), nn.Sigmoid(),\n", " nn.Linear( 30, 20), nn.Sigmoid(),\n", " nn.Linear( 20, 10)\n", ")\n", "\n", "loss['mlp'] = train_model(model, optim.SGD(model.parameters(), lr=0.01))" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 2.286 val 2.287 test 2.287\n", "epoch 1 train 2.249 val 2.248 test 2.246\n", "epoch 2 train 2.191 val 2.189 test 2.187\n", "epoch 3 train 2.110 val 2.106 test 2.104\n", "epoch 4 train 1.976 val 1.974 test 1.971\n", "epoch 5 train 1.786 val 1.781 test 1.783\n", "epoch 6 train 1.539 val 1.526 test 1.524\n", "epoch 7 train 1.277 val 1.278 test 1.271\n", "epoch 8 train 1.072 val 1.072 test 1.076\n", "epoch 9 train 0.955 val 0.936 test 0.947\n" ] } ], "source": [ "model = nn.Sequential(\n", " nn.Linear(784, 30), nn.ReLU(),\n", " nn.Linear( 30, 20), nn.ReLU(),\n", " nn.Linear( 20, 10)\n", ")\n", "\n", "loss['mlp_relu'] = train_model(model, optim.SGD(model.parameters(), lr=0.01))" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 1.449 val 1.457 test 1.450\n", "epoch 1 train 1.072 val 1.074 test 1.054\n", "epoch 2 train 0.801 val 0.803 test 0.790\n", "epoch 3 train 0.617 val 0.637 test 0.615\n", "epoch 4 train 0.511 val 0.514 test 0.502\n", "epoch 5 train 0.420 val 0.435 test 0.430\n", "epoch 6 train 0.370 val 0.397 test 0.369\n", "epoch 7 train 0.332 val 0.340 test 0.332\n", "epoch 8 train 0.317 val 0.319 test 0.308\n", "epoch 9 train 0.279 val 0.298 test 0.280\n" ] } ], "source": [ "model = nn.Sequential(\n", " nn.Linear(784, 30, bias=False), nn.LayerNorm(30), nn.ReLU(),\n", " nn.Linear( 30, 20, bias=False), nn.LayerNorm(20), nn.ReLU(),\n", " nn.Linear( 20, 10)\n", ")\n", "\n", "loss['mlp_relu_layer'] = train_model(model, optim.SGD(model.parameters(), lr=0.01))" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for i in ['mlp', 'mlp_relu', 'mlp_relu_layer']:\n", " plt.plot(torch.tensor(loss[i]).view(-1, 10).mean(dim=-1), label=i)\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 0.200 val 0.197 test 0.217\n", "epoch 1 train 0.129 val 0.157 test 0.147\n", "epoch 2 train 0.100 val 0.133 test 0.140\n", "epoch 3 train 0.088 val 0.142 test 0.125\n", "epoch 4 train 0.088 val 0.142 test 0.139\n", "epoch 5 train 0.075 val 0.137 test 0.132\n", "epoch 6 train 0.063 val 0.139 test 0.120\n", "epoch 7 train 0.059 val 0.118 test 0.119\n", "epoch 8 train 0.056 val 0.133 test 0.119\n", "epoch 9 train 0.055 val 0.118 test 0.125\n", "epoch 10 train 0.046 val 0.112 test 0.135\n", "epoch 11 train 0.042 val 0.128 test 0.122\n", "epoch 12 train 0.039 val 0.127 test 0.147\n", "epoch 13 train 0.044 val 0.115 test 0.126\n", "epoch 14 train 0.031 val 0.116 test 0.120\n", "epoch 15 train 0.031 val 0.134 test 0.133\n", "epoch 16 train 0.027 val 0.117 test 0.137\n", "epoch 17 train 0.024 val 0.135 test 0.117\n", "epoch 18 train 0.027 val 0.141 test 0.145\n", "epoch 19 train 0.027 val 0.133 test 0.178\n" ] } ], "source": [ "# 模型过拟合\n", "model = nn.Sequential(\n", " nn.Linear(784, 30, bias=False), nn.LayerNorm(30), nn.ReLU(),\n", " nn.Linear( 30, 20, bias=False), nn.LayerNorm(20), nn.ReLU(),\n", " nn.Linear( 20, 10)\n", ")\n", "\n", "loss['mlp_relu_layer'] = train_model(model, optim.Adam(model.parameters(), lr=0.01), epochs=20)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "m = nn.Dropout(0.5)\n", "x = torch.randn(5, requires_grad=True)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 1.3846, 0.5061, 0.3079, 1.1239, -1.9939], requires_grad=True)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([0.0000, 1.0123, 0.0000, 0.0000, -0.0000], grad_fn=),\n", " tensor([0., 4., 2., 0., 0.]))" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.train()\n", "l = m(x)\n", "l.sum().backward()\n", "l, x.grad" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 1.3846, 0.5061, 0.3079, 1.1239, -1.9939], requires_grad=True)" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.eval()\n", "m(x)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.training" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 0.265 val 0.260 test 0.261\n", "epoch 1 train 0.182 val 0.200 test 0.204\n", "epoch 2 train 0.153 val 0.187 test 0.196\n", "epoch 3 train 0.161 val 0.184 test 0.173\n", "epoch 4 train 0.179 val 0.178 test 0.170\n", "epoch 5 train 0.134 val 0.167 test 0.162\n", "epoch 6 train 0.122 val 0.171 test 0.164\n", "epoch 7 train 0.124 val 0.157 test 0.157\n", "epoch 8 train 0.116 val 0.160 test 0.149\n", "epoch 9 train 0.128 val 0.154 test 0.158\n", "epoch 10 train 0.104 val 0.153 test 0.147\n", "epoch 11 train 0.121 val 0.153 test 0.149\n", "epoch 12 train 0.117 val 0.142 test 0.158\n", "epoch 13 train 0.099 val 0.143 test 0.151\n", "epoch 14 train 0.110 val 0.159 test 0.139\n", "epoch 15 train 0.110 val 0.159 test 0.138\n", "epoch 16 train 0.105 val 0.138 test 0.151\n", "epoch 17 train 0.096 val 0.131 test 0.142\n", "epoch 18 train 0.089 val 0.153 test 0.156\n", "epoch 19 train 0.106 val 0.155 test 0.164\n" ] } ], "source": [ "model = nn.Sequential(\n", " nn.Linear(784, 30, bias=False), nn.LayerNorm(30), nn.ReLU(), nn.Dropout(0.2),\n", " nn.Linear( 30, 20, bias=False), nn.LayerNorm(20), nn.ReLU(), nn.Dropout(0.2),\n", " nn.Linear( 20, 10)\n", ")\n", "\n", "loss['mlp_relu_layer_dropout'] = train_model(model, optim.Adam(model.parameters(), lr=0.01), epochs=20)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 train 0.286 val 0.292 test 0.287\n", "epoch 1 train 0.256 val 0.253 test 0.235\n", "epoch 2 train 0.247 val 0.261 test 0.242\n", "epoch 3 train 0.218 val 0.241 test 0.218\n", "epoch 4 train 0.211 val 0.220 test 0.205\n", "epoch 5 train 0.227 val 0.235 test 0.227\n", "epoch 6 train 0.219 val 0.243 test 0.231\n", "epoch 7 train 0.209 val 0.233 test 0.217\n", "epoch 8 train 0.215 val 0.210 test 0.201\n", "epoch 9 train 0.205 val 0.208 test 0.204\n", "epoch 10 train 0.262 val 0.279 test 0.249\n", "epoch 11 train 0.200 val 0.221 test 0.208\n", "epoch 12 train 0.208 val 0.237 test 0.212\n", "epoch 13 train 0.219 val 0.226 test 0.206\n", "epoch 14 train 0.209 val 0.222 test 0.216\n", "epoch 15 train 0.194 val 0.236 test 0.215\n", "epoch 16 train 0.204 val 0.212 test 0.206\n", "epoch 17 train 0.213 val 0.231 test 0.208\n", "epoch 18 train 0.202 val 0.232 test 0.211\n", "epoch 19 train 0.228 val 0.247 test 0.236\n" ] } ], "source": [ "# 惩罚项\n", "model = nn.Sequential(\n", " nn.Linear(784, 30, bias=False), nn.LayerNorm(30), nn.ReLU(), \n", " nn.Linear( 30, 20, bias=False), nn.LayerNorm(20), nn.ReLU(), \n", " nn.Linear( 20, 10)\n", ")\n", "\n", "_ = train_model(model, optim.Adam(model.parameters(), lr=0.01), epochs=20, penalty=True)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': {'loss': 0.21390601992607117, 'acc': 0.9363999366760254},\n", " 'val': {'loss': 0.26335230469703674, 'acc': 0.9261999130249023},\n", " 'test': {'loss': 0.24580919742584229, 'acc': 0.928600013256073}}" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "estimate_loss(model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }