@@ -0,0 +1,508 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def conv_example(in_channel, kernel):\n",
+    "    # in_channel: (28, 28)\n",
+    "    # kernel: ( 5, 5)\n",
+    "    output = torch.zeros(24, 24)\n",
+    "    for h in range(24):\n",
+    "        for w in range(24):\n",
+    "            inputs = in_channel[h: h + 5, w: w + 5]\n",
+    "            output[h, w] = (inputs * kernel).sum()\n",
+    "    return output"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "m = nn.Conv2d(1, 1, (5, 5), bias=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.randn(1, 1, 28, 28)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 1, 24, 24])"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "re = m(x)\n",
+    "re.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 1, 5, 5])"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "m.weight.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "re1 = conv_example(x.squeeze(0, 1), m.weight.squeeze(0, 1))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([24, 24])"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "re1.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(True)"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.all((re - re1).abs() < 0.001)"
+   ]
+  },
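+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal sketch of the same cross-check via the functional API: F.conv2d applies\n",
+    "# m's weight to x directly, so it should agree with both m(x) and the hand-written\n",
+    "# loop up to floating-point error. re2 is an illustrative name.\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "re2 = F.conv2d(x, m.weight)  # (1, 1, 24, 24); no bias term, since m was built with bias=False\n",
+    "torch.allclose(re, re2, atol=1e-5)"
+   ]
+  },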
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([4, 3, 5, 5])"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "m1 = nn.Conv2d(3, 4, (5, 5))\n",
+    "m1.weight.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([10, 4, 24, 24])"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x1 = torch.randn(10, 3, 28, 28)\n",
+    "m1(x1).shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([10, 3, 14, 14])"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "p = nn.MaxPool2d(2, 2)\n",
+    "p(x1).shape"
+   ]
+  },
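+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal sketch of what MaxPool2d(2, 2) computes: the maximum over each\n",
+    "# non-overlapping 2x2 window. Grouping (H, W) into (H/2, 2, W/2, 2) and reducing\n",
+    "# over the two window axes should reproduce p(x1) exactly. pooled is an illustrative name.\n",
+    "B, C, H, W = x1.shape\n",
+    "pooled = x1.view(B, C, H // 2, 2, W // 2, 2).amax(dim=(3, 5))  # (10, 3, 14, 14)\n",
+    "torch.equal(pooled, p(x1))"
+   ]
+  },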
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "### Implementation of the convolutional neural network\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.optim as optim\n",
+    "from torch.utils.data import DataLoader, random_split\n",
+    "from torchvision import datasets\n",
+    "import torchvision.transforms as transforms\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "\n",
+    "torch.manual_seed(12046)\n",
+    "\n",
+    "dataset = datasets.MNIST(root='./mnist', train=True, download=True, transform=transforms.ToTensor())\n",
+    "train_set, val_set = random_split(dataset, [50000, 10000])\n",
+    "test_set = datasets.MNIST(root='./mnist', train=False, download=True, transform=transforms.ToTensor())\n",
+    "\n",
+    "train_loader = DataLoader(train_set, batch_size=500, shuffle=True)\n",
+    "val_loader = DataLoader(val_set, batch_size=500, shuffle=True)\n",
+    "test_loader = DataLoader(test_set, batch_size=500, shuffle=True)"
+   ]
+  },
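+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A quick, illustrative sanity check on the data pipeline: pull one batch from\n",
+    "# train_loader and display the first image with its label. ToTensor() yields\n",
+    "# tensors of shape (1, 28, 28) with values in [0, 1]. imgs/labels are illustrative names.\n",
+    "imgs, labels = next(iter(train_loader))  # imgs: (500, 1, 28, 28), labels: (500,)\n",
+    "plt.imshow(imgs[0, 0], cmap='gray')\n",
+    "plt.title(f'label: {labels[0].item()}')\n",
+    "plt.show()"
+   ]
+  },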
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CNN(nn.Module):\n",
+    "\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.conv1 = nn.Conv2d(1, 20, (5, 5))\n",
+    "        self.pool1 = nn.MaxPool2d(2, 2)\n",
+    "        self.conv2 = nn.Conv2d(20, 40, (5, 5))\n",
+    "        self.pool2 = nn.MaxPool2d(2, 2)\n",
+    "        self.fc1 = nn.Linear(40 * 4 * 4, 120)\n",
+    "        self.fc2 = nn.Linear(120, 10)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        # x : (B, 1, 28, 28)\n",
+    "        B = x.shape[0]                       # (B, 1, 28, 28)\n",
+    "        x = F.relu(self.conv1(x))            # (B, 20, 24, 24)\n",
+    "        x = self.pool1(x)                    # (B, 20, 12, 12)\n",
+    "        x = F.relu(self.conv2(x))            # (B, 40, 8, 8)\n",
+    "        x = self.pool2(x)                    # (B, 40, 4, 4)\n",
+    "        x = F.relu(self.fc1(x.view(B, -1)))  # (B, 120)\n",
+    "        x = self.fc2(x)                      # (B, 10)\n",
+    "        return x\n",
+    "\n",
+    "model = CNN()"
+   ]
+  },
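+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal sketch to confirm the shape bookkeeping in forward(): a dummy batch of\n",
+    "# two 28x28 images should come out as (2, 10) class scores (logits).\n",
+    "model(torch.randn(2, 1, 28, 28)).shape"
+   ]
+  },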
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "eval_iters = 10\n",
+    "\n",
+    "\n",
+    "def estimate_loss(model):\n",
+    "    re = {}\n",
+    "    # Switch the model to evaluation mode\n",
+    "    model.eval()\n",
+    "    re['train'] = _loss(model, train_loader)\n",
+    "    re['val'] = _loss(model, val_loader)\n",
+    "    re['test'] = _loss(model, test_loader)\n",
+    "    # Switch the model back to training mode\n",
+    "    model.train()\n",
+    "    return re\n",
+    "\n",
+    "\n",
+    "@torch.no_grad()\n",
+    "def _loss(model, dataloader):\n",
+    "    # Estimate the model's loss and accuracy on a few batches\n",
+    "    loss = []\n",
+    "    acc = []\n",
+    "    data_iter = iter(dataloader)\n",
+    "    for t in range(eval_iters):\n",
+    "        inputs, labels = next(data_iter)\n",
+    "        # inputs: (500, 1, 28, 28)\n",
+    "        # labels: (500)\n",
+    "        B, C, H, W = inputs.shape\n",
+    "        # logits = model(inputs.view(B, -1))\n",
+    "        logits = model(inputs)\n",
+    "        loss.append(F.cross_entropy(logits, labels))\n",
+    "        # preds = torch.argmax(F.softmax(logits, dim=-1), dim=-1)\n",
+    "        preds = torch.argmax(logits, dim=-1)\n",
+    "        acc.append((preds == labels).sum() / B)\n",
+    "    re = {\n",
+    "        'loss': torch.tensor(loss).mean().item(),\n",
+    "        'acc': torch.tensor(acc).mean().item()\n",
+    "    }\n",
+    "    return re"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train_model(model, optimizer, epochs=10, penalty=False):\n",
+    "    lossi = []\n",
+    "    for e in range(epochs):\n",
+    "        for data in train_loader:\n",
+    "            inputs, labels = data\n",
+    "            # B, C, H, W = inputs.shape\n",
+    "            # logits = model(inputs.view(B, -1))\n",
+    "            logits = model(inputs)\n",
+    "            loss = F.cross_entropy(logits, labels)\n",
+    "            lossi.append(loss.item())\n",
+    "            if penalty:\n",
+    "                # Optional L1 + L2 penalty on all parameters\n",
+    "                w = torch.cat([p.view(-1) for p in model.parameters()])\n",
+    "                loss += 0.001 * w.abs().sum() + 0.002 * w.square().sum()\n",
+    "            optimizer.zero_grad()\n",
+    "            loss.backward()\n",
+    "            optimizer.step()\n",
+    "        stats = estimate_loss(model)\n",
+    "        train_loss = f'{stats[\"train\"][\"loss\"]:.3f}'\n",
+    "        val_loss = f'{stats[\"val\"][\"loss\"]:.3f}'\n",
+    "        test_loss = f'{stats[\"test\"][\"loss\"]:.3f}'\n",
+    "        print(f'epoch {e} train {train_loss} val {val_loss} test {test_loss}')\n",
+    "    return lossi"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "epoch 0 train 0.064 val 0.068 test 0.070\n",
+      "epoch 1 train 0.038 val 0.048 test 0.036\n",
+      "epoch 2 train 0.032 val 0.036 test 0.039\n",
+      "epoch 3 train 0.028 val 0.039 test 0.035\n",
+      "epoch 4 train 0.017 val 0.046 test 0.027\n",
+      "epoch 5 train 0.018 val 0.035 test 0.039\n",
+      "epoch 6 train 0.018 val 0.053 test 0.047\n",
+      "epoch 7 train 0.010 val 0.042 test 0.037\n",
+      "epoch 8 train 0.007 val 0.044 test 0.035\n",
+      "epoch 9 train 0.018 val 0.065 test 0.053\n"
+     ]
+    }
+   ],
+   "source": [
+    "_ = train_model(model, optim.Adam(model.parameters(), lr=0.01))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'train': {'loss': 0.02370004542171955, 'acc': 0.9912000894546509},\n",
+       " 'val': {'loss': 0.05751935765147209, 'acc': 0.9878000020980835},\n",
+       " 'test': {'loss': 0.06482766568660736, 'acc': 0.984000027179718}}"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "estimate_loss(model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CNN2(nn.Module):\n",
+    "\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.conv1 = nn.Conv2d(1, 20, (5, 5))\n",
+    "        self.ln1 = nn.LayerNorm([20, 24, 24])\n",
+    "        self.pool1 = nn.MaxPool2d(2, 2)\n",
+    "        self.conv2 = nn.Conv2d(20, 40, (5, 5))\n",
+    "        self.ln2 = nn.LayerNorm([40, 8, 8])\n",
+    "        self.pool2 = nn.MaxPool2d(2, 2)\n",
+    "        self.fc1 = nn.Linear(40 * 4 * 4, 120)\n",
+    "        self.dp = nn.Dropout(0.2)\n",
+    "        self.fc2 = nn.Linear(120, 10)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        # x : (B, 1, 28, 28)\n",
+    "        B = x.shape[0]                       # (B, 1, 28, 28)\n",
+    "        x = F.relu(self.ln1(self.conv1(x)))  # (B, 20, 24, 24)\n",
+    "        x = self.pool1(x)                    # (B, 20, 12, 12)\n",
+    "        x = F.relu(self.ln2(self.conv2(x)))  # (B, 40, 8, 8)\n",
+    "        x = self.pool2(x)                    # (B, 40, 4, 4)\n",
+    "        x = F.relu(self.fc1(x.view(B, -1)))  # (B, 120)\n",
+    "        x = self.dp(x)\n",
+    "        x = self.fc2(x)                      # (B, 10)\n",
+    "        return x\n",
+    "\n",
+    "model2 = CNN2()"
+   ]
+  },
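+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A small illustrative comparison of model sizes: the two LayerNorm layers add their\n",
+    "# elementwise affine weight and bias (about 28k extra parameters), while Dropout adds\n",
+    "# none. n1 and n2 are illustrative names.\n",
+    "n1 = sum(p.numel() for p in model.parameters())\n",
+    "n2 = sum(p.numel() for p in model2.parameters())\n",
+    "n1, n2"
+   ]
+  },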
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "epoch 0 train 0.084 val 0.084 test 0.079\n",
+      "epoch 1 train 0.051 val 0.052 test 0.040\n",
+      "epoch 2 train 0.026 val 0.045 test 0.038\n",
+      "epoch 3 train 0.035 val 0.050 test 0.046\n",
+      "epoch 4 train 0.034 val 0.055 test 0.037\n",
+      "epoch 5 train 0.024 val 0.035 test 0.031\n",
+      "epoch 6 train 0.014 val 0.039 test 0.028\n",
+      "epoch 7 train 0.018 val 0.042 test 0.031\n",
+      "epoch 8 train 0.017 val 0.040 test 0.042\n",
+      "epoch 9 train 0.012 val 0.043 test 0.029\n"
+     ]
+    }
+   ],
+   "source": [
+    "_ = train_model(model2, optim.Adam(model2.parameters(), lr=0.01))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}