|
|
@@ -65,12 +65,13 @@
|
|
|
"source": [
|
|
|
"from torchvision import datasets, transforms\n",
|
|
|
"data_path = '../data-unversioned/p1ch6/'\n",
|
|
|
- "cifar10 = datasets.CIFAR10(data_path, train=True, download=True,\n",
|
|
|
- " transform=transforms.Compose([\n",
|
|
|
- " transforms.ToTensor(),\n",
|
|
|
- " transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
|
|
|
- " (0.2470, 0.2435, 0.2616))\n",
|
|
|
- " ]))"
|
|
|
+ "cifar10 = datasets.CIFAR10(\n",
|
|
|
+ " data_path, train=True, download=True,\n",
|
|
|
+ " transform=transforms.Compose([\n",
|
|
|
+ " transforms.ToTensor(),\n",
|
|
|
+ " transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
|
|
|
+ " (0.2470, 0.2435, 0.2616))\n",
|
|
|
+ " ]))"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -87,12 +88,13 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
- "cifar10_val = datasets.CIFAR10(data_path, train=False, download=True,\n",
|
|
|
- " transform=transforms.Compose([\n",
|
|
|
- " transforms.ToTensor(),\n",
|
|
|
- " transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
|
|
|
- " (0.2470, 0.2435, 0.2616))\n",
|
|
|
- " ]))"
|
|
|
+ "cifar10_val = datasets.CIFAR10(\n",
|
|
|
+ " data_path, train=False, download=True,\n",
|
|
|
+ " transform=transforms.Compose([\n",
|
|
|
+ " transforms.ToTensor(),\n",
|
|
|
+ " transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
|
|
|
+ " (0.2470, 0.2435, 0.2616))\n",
|
|
|
+ " ]))"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -103,8 +105,12 @@
|
|
|
"source": [
|
|
|
"label_map = {0: 0, 2: 1}\n",
|
|
|
"class_names = ['airplane', 'bird']\n",
|
|
|
- "cifar2 = [(img, label_map[label]) for img, label in cifar10 if label in [0, 2]]\n",
|
|
|
- "cifar2_val = [(img, label_map[label]) for img, label in cifar10_val if label in [0, 2]]"
|
|
|
+ "cifar2 = [(img, label_map[label])\n",
|
|
|
+ " for img, label in cifar10\n",
|
|
|
+ " if label in [0, 2]]\n",
|
|
|
+ "cifar2_val = [(img, label_map[label])\n",
|
|
|
+ " for img, label in cifar10_val\n",
|
|
|
+ " if label in [0, 2]]"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -140,7 +146,9 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
- "numel_list = [p.numel() for p in connected_model.parameters() if p.requires_grad == True]\n",
|
|
|
+ "numel_list = [p.numel()\n",
|
|
|
+ " for p in connected_model.parameters()\n",
|
|
|
+ " if p.requires_grad == True]\n",
|
|
|
"sum(numel_list), numel_list"
|
|
|
]
|
|
|
},
|
|
|
@@ -648,18 +656,23 @@
|
|
|
" for epoch in range(1, n_epochs + 1): # <2>\n",
|
|
|
" loss_train = 0.0\n",
|
|
|
" for imgs, labels in train_loader: # <3>\n",
|
|
|
+ " \n",
|
|
|
" outputs = model(imgs) # <4>\n",
|
|
|
+ " \n",
|
|
|
" loss = loss_fn(outputs, labels) # <5>\n",
|
|
|
"\n",
|
|
|
" optimizer.zero_grad() # <6>\n",
|
|
|
+ " \n",
|
|
|
" loss.backward() # <7>\n",
|
|
|
+ " \n",
|
|
|
" optimizer.step() # <8>\n",
|
|
|
"\n",
|
|
|
" loss_train += loss.item() # <9>\n",
|
|
|
"\n",
|
|
|
" if epoch == 1 or epoch % 10 == 0:\n",
|
|
|
" print('{} Epoch {}, Training loss {}'.format(\n",
|
|
|
- " datetime.datetime.now(), epoch, loss_train / len(train_loader))) # <10>"
|
|
|
+ " datetime.datetime.now(), epoch,\n",
|
|
|
+ " loss_train / len(train_loader))) # <10>"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -686,7 +699,8 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
- "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True) # <1>\n",
|
|
|
+ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
|
|
|
+ " shuffle=True) # <1>\n",
|
|
|
"\n",
|
|
|
"model = Net() # <2>\n",
|
|
|
"optimizer = optim.SGD(model.parameters(), lr=1e-2) # <3>\n",
|
|
|
@@ -716,8 +730,10 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
- "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n",
|
|
|
- "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
|
|
|
+ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
|
|
|
+ " shuffle=False)\n",
|
|
|
+ "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n",
|
|
|
+ " shuffle=False)\n",
|
|
|
"\n",
|
|
|
"def validate(model, train_loader, val_loader):\n",
|
|
|
" for name, loader in [(\"train\", train_loader), (\"val\", val_loader)]:\n",
|
|
|
@@ -763,7 +779,8 @@
|
|
|
],
|
|
|
"source": [
|
|
|
"loaded_model = Net() # <1>\n",
|
|
|
- "loaded_model.load_state_dict(torch.load(data_path + 'birds_vs_airplanes.pt'))"
|
|
|
+ "loaded_model.load_state_dict(torch.load(data_path\n",
|
|
|
+ " + 'birds_vs_airplanes.pt'))"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -780,7 +797,8 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
- "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
|
|
|
+ "device = (torch.device('cuda') if torch.cuda.is_available()\n",
|
|
|
+ " else torch.device('cpu'))\n",
|
|
|
"print(f\"Training on device {device}.\")"
|
|
|
]
|
|
|
},
|
|
|
@@ -809,7 +827,8 @@
|
|
|
"\n",
|
|
|
" if epoch == 1 or epoch % 10 == 0:\n",
|
|
|
" print('{} Epoch {}, Training loss {}'.format(\n",
|
|
|
- " datetime.datetime.now(), epoch, loss_train / len(train_loader)))"
|
|
|
+ " datetime.datetime.now(), epoch,\n",
|
|
|
+ " loss_train / len(train_loader)))"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -836,7 +855,8 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
- "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)\n",
|
|
|
+ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
|
|
|
+ " shuffle=True)\n",
|
|
|
"\n",
|
|
|
"model = Net().to(device=device) # <1>\n",
|
|
|
"optimizer = optim.SGD(model.parameters(), lr=1e-2)\n",
|
|
|
@@ -866,8 +886,10 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
- "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=False)\n",
|
|
|
- "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)\n",
|
|
|
+ "train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64,\n",
|
|
|
+ " shuffle=False)\n",
|
|
|
+ "val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64,\n",
|
|
|
+ " shuffle=False)\n",
|
|
|
"all_acc_dict = collections.OrderedDict()\n",
|
|
|
"\n",
|
|
|
"def validate(model, train_loader, val_loader):\n",
|
|
|
@@ -910,7 +932,9 @@
|
|
|
],
|
|
|
"source": [
|
|
|
"loaded_model = Net().to(device=device)\n",
|
|
|
- "loaded_model.load_state_dict(torch.load(data_path + 'birds_vs_airplanes.pt', map_location=device))"
|
|
|
+ "loaded_model.load_state_dict(torch.load(data_path\n",
|
|
|
+ " + 'birds_vs_airplanes.pt',\n",
|
|
|
+ " map_location=device))"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -998,7 +1022,8 @@
|
|
|
" super().__init__()\n",
|
|
|
" self.n_chans1 = n_chans1\n",
|
|
|
" self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
|
|
|
- " self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, padding=1)\n",
|
|
|
+ " self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3,\n",
|
|
|
+ " padding=1)\n",
|
|
|
" self.fc1 = nn.Linear(8 * 8 * n_chans1 // 2, 32)\n",
|
|
|
" self.fc2 = nn.Linear(32, 2)\n",
|
|
|
" \n",
|
|
|
@@ -1078,7 +1103,8 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "def training_loop_l2reg(n_epochs, optimizer, model, loss_fn, train_loader):\n",
|
|
|
+ "def training_loop_l2reg(n_epochs, optimizer, model, loss_fn,\n",
|
|
|
+ " train_loader):\n",
|
|
|
" for epoch in range(1, n_epochs + 1):\n",
|
|
|
" loss_train = 0.0\n",
|
|
|
" for imgs, labels in train_loader:\n",
|
|
|
@@ -1088,7 +1114,8 @@
|
|
|
" loss = loss_fn(outputs, labels)\n",
|
|
|
"\n",
|
|
|
" l2_lambda = 0.001\n",
|
|
|
- " l2_norm = sum(p.pow(2.0).sum() for p in model.parameters()) # <1>\n",
|
|
|
+ " l2_norm = sum(p.pow(2.0).sum()\n",
|
|
|
+ " for p in model.parameters()) # <1>\n",
|
|
|
" loss = loss + l2_lambda * l2_norm\n",
|
|
|
"\n",
|
|
|
" optimizer.zero_grad()\n",
|
|
|
@@ -1098,7 +1125,8 @@
|
|
|
" loss_train += loss.item()\n",
|
|
|
" if epoch == 1 or epoch % 10 == 0:\n",
|
|
|
" print('{} Epoch {}, Training loss {}'.format(\n",
|
|
|
- " datetime.datetime.now(), epoch, loss_train / len(train_loader)))\n"
|
|
|
+ " datetime.datetime.now(), epoch,\n",
|
|
|
+ " loss_train / len(train_loader)))\n"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
@@ -1153,7 +1181,8 @@
|
|
|
" self.n_chans1 = n_chans1\n",
|
|
|
" self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
|
|
|
" self.conv1_dropout = nn.Dropout2d(p=0.4)\n",
|
|
|
- " self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, padding=1)\n",
|
|
|
+ " self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3,\n",
|
|
|
+ " padding=1)\n",
|
|
|
" self.conv2_dropout = nn.Dropout2d(p=0.4)\n",
|
|
|
" self.fc1 = nn.Linear(8 * 8 * n_chans1 // 2, 32)\n",
|
|
|
" self.fc2 = nn.Linear(32, 2)\n",
|
|
|
@@ -1221,7 +1250,8 @@
|
|
|
" self.n_chans1 = n_chans1\n",
|
|
|
" self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
|
|
|
" self.conv1_batchnorm = nn.BatchNorm2d(num_features=n_chans1)\n",
|
|
|
- " self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, padding=1)\n",
|
|
|
+ " self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, \n",
|
|
|
+ " padding=1)\n",
|
|
|
" self.conv2_batchnorm = nn.BatchNorm2d(num_features=n_chans1 // 2)\n",
|
|
|
" self.fc1 = nn.Linear(8 * 8 * n_chans1 // 2, 32)\n",
|
|
|
" self.fc2 = nn.Linear(32, 2)\n",
|
|
|
@@ -1288,8 +1318,10 @@
|
|
|
" super().__init__()\n",
|
|
|
" self.n_chans1 = n_chans1\n",
|
|
|
" self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
|
|
|
- " self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, padding=1)\n",
|
|
|
- " self.conv3 = nn.Conv2d(n_chans1 // 2, n_chans1 // 2, kernel_size=3, padding=1)\n",
|
|
|
+ " self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3,\n",
|
|
|
+ " padding=1)\n",
|
|
|
+ " self.conv3 = nn.Conv2d(n_chans1 // 2, n_chans1 // 2,\n",
|
|
|
+ " kernel_size=3, padding=1)\n",
|
|
|
" self.fc1 = nn.Linear(4 * 4 * n_chans1 // 2, 32)\n",
|
|
|
" self.fc2 = nn.Linear(32, 2)\n",
|
|
|
" \n",
|
|
|
@@ -1354,8 +1386,10 @@
|
|
|
" super().__init__()\n",
|
|
|
" self.n_chans1 = n_chans1\n",
|
|
|
" self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
|
|
|
- " self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3, padding=1)\n",
|
|
|
- " self.conv3 = nn.Conv2d(n_chans1 // 2, n_chans1 // 2, kernel_size=3, padding=1)\n",
|
|
|
+ " self.conv2 = nn.Conv2d(n_chans1, n_chans1 // 2, kernel_size=3,\n",
|
|
|
+ " padding=1)\n",
|
|
|
+ " self.conv3 = nn.Conv2d(n_chans1 // 2, n_chans1 // 2,\n",
|
|
|
+ " kernel_size=3, padding=1)\n",
|
|
|
" self.fc1 = nn.Linear(4 * 4 * n_chans1 // 2, 32)\n",
|
|
|
" self.fc2 = nn.Linear(32, 2)\n",
|
|
|
" \n",
|
|
|
@@ -1419,9 +1453,11 @@
|
|
|
"class ResBlock(nn.Module):\n",
|
|
|
" def __init__(self, n_chans):\n",
|
|
|
" super(ResBlock, self).__init__()\n",
|
|
|
- " self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False) # <1>\n",
|
|
|
+ " self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3,\n",
|
|
|
+ " padding=1, bias=False) # <1>\n",
|
|
|
" self.batch_norm = nn.BatchNorm2d(num_features=n_chans)\n",
|
|
|
- " torch.nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu') # <2>\n",
|
|
|
+ " torch.nn.init.kaiming_normal_(self.conv.weight,\n",
|
|
|
+ " nonlinearity='relu') # <2>\n",
|
|
|
" torch.nn.init.constant_(self.batch_norm.weight, 0.5)\n",
|
|
|
" torch.nn.init.zeros_(self.batch_norm.bias)\n",
|
|
|
"\n",
|
|
|
@@ -1443,7 +1479,8 @@
|
|
|
" super().__init__()\n",
|
|
|
" self.n_chans1 = n_chans1\n",
|
|
|
" self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)\n",
|
|
|
- " self.resblocks = nn.Sequential(* [ResBlock(n_chans=n_chans1)] * n_blocks)\n",
|
|
|
+ " self.resblocks = nn.Sequential(\n",
|
|
|
+ " *(n_blocks * [ResBlock(n_chans=n_chans1)]))\n",
|
|
|
" self.fc1 = nn.Linear(8 * 8 * n_chans1, 32)\n",
|
|
|
" self.fc2 = nn.Linear(32, 2)\n",
|
|
|
" \n",
|
|
|
@@ -1523,7 +1560,8 @@
|
|
|
"width =0.3\n",
|
|
|
"plt.bar(np.arange(len(trn_acc)), trn_acc, width=width, label='train')\n",
|
|
|
"plt.bar(np.arange(len(val_acc))+ width, val_acc, width=width, label='val')\n",
|
|
|
- "plt.xticks(np.arange(len(val_acc))+ width/2, list(all_acc_dict.keys()), rotation=60)\n",
|
|
|
+ "plt.xticks(np.arange(len(val_acc))+ width/2, list(all_acc_dict.keys()),\n",
|
|
|
+ " rotation=60)\n",
|
|
|
"plt.ylabel('accuracy')\n",
|
|
|
"plt.legend(loc='lower right')\n",
|
|
|
"plt.ylim(0.7, 1)\n",
|