|
|
@@ -125,6 +125,7 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
+ "# 有漏洞的残差连接\n",
|
|
|
"class ResidualBlockBugVersion(nn.Module):\n",
|
|
|
" \n",
|
|
|
" def __init__(self, in_channel, out_channel, stride=1):\n",
|
|
|
@@ -143,7 +144,7 @@
|
|
|
" out = F.relu(self.bn1(self.conv1(x)))\n",
|
|
|
" out = self.bn2(self.conv2(out))\n",
|
|
|
" # 残差连接\n",
|
|
|
- " ## 如果stride != 1 or in_channel != out_channel,\n",
|
|
|
+ " ## 如果stride != 1或者in_channel != out_channel,\n",
|
|
|
" ## 下面的计算会出错,因为out和inputs的形状不一样\n",
|
|
|
" out += inputs\n",
|
|
|
" out = F.relu(out)\n",
|
|
|
@@ -174,6 +175,14 @@
|
|
|
"class ResidualBlock(nn.Module):\n",
|
|
|
" \n",
|
|
|
" def __init__(self, in_channel, out_channel, stride=1):\n",
|
|
|
+ " '''\n",
|
|
|
+ " 定义残差块\n",
|
|
|
+ " 参数\n",
|
|
|
+ " ----\n",
|
|
|
+ " in_channel :int,输入通道\n",
|
|
|
+ " out_channel :int,输出通道\n",
|
|
|
+ " stride :int,步幅大小\n",
|
|
|
+ " '''\n",
|
|
|
" super().__init__()\n",
|
|
|
" self.conv1 = nn.Conv2d(\n",
|
|
|
" in_channel, out_channel, (3, 3), \n",
|
|
|
@@ -183,14 +192,24 @@
|
|
|
" out_channel, out_channel, (3, 3),\n",
|
|
|
" stride=1, padding=1, bias=False)\n",
|
|
|
" self.bn2 = nn.BatchNorm2d(out_channel)\n",
|
|
|
- " # 让输入的形状和输出的形状一样\n",
|
|
|
" self.downsample = None\n",
|
|
|
+ " # 如果stride != 1或者in_channel != out_channel,那么输入的形状和输出的形状不一样\n",
|
|
|
+ " # 使用下面的变换使得两个张量的形状一样\n",
|
|
|
" if stride != 1 or in_channel != out_channel:\n",
|
|
|
+ " # 下面两个卷积操作的输出形状是一样的\n",
|
|
|
+ " # Conv2d(in_channel, out_channel, (3, 3), stride, padding=1)\n",
|
|
|
+ " # Conv2d(in_channel, out_channel, (1, 1), stride, padding=0)\n",
|
|
|
" self.downsample = nn.Sequential(\n",
|
|
|
" nn.Conv2d(in_channel, out_channel, (1, 1), stride=stride, bias=False),\n",
|
|
|
" nn.BatchNorm2d(out_channel))\n",
|
|
|
" \n",
|
|
|
" def forward(self, x):\n",
|
|
|
+ " '''\n",
|
|
|
+ " 向前传播\n",
|
|
|
+ " 参数\n",
|
|
|
+ " ----\n",
|
|
|
+ " x :torch.FloatTensor,形状为(B, I, H, W)\n",
|
|
|
+ " '''\n",
|
|
|
" inputs = x\n",
|
|
|
" out = F.relu(self.bn1(self.conv1(x)))\n",
|
|
|
" out = self.bn2(self.conv2(out))\n",
|
|
|
@@ -224,13 +243,19 @@
|
|
|
" self.lm = nn.Linear(120, 10)\n",
|
|
|
"\n",
|
|
|
" def forward(self, x):\n",
|
|
|
+ " '''\n",
|
|
|
+ " 向前传播\n",
|
|
|
+ " 参数\n",
|
|
|
+ " ----\n",
|
|
|
+ " x :torch.FloatTensor,形状为(B, 1, 28, 28)\n",
|
|
|
+ " '''\n",
|
|
|
" x = self.block1(x) # (B, 20, 28, 28)\n",
|
|
|
" x = self.block2(x) # (B, 40, 14, 14)\n",
|
|
|
" x = self.block3(x) # (B, 60, 7, 7)\n",
|
|
|
" x = self.block4(x) # (B, 60, 4, 4)\n",
|
|
|
" x = self.block5(x) # (B, 60, 2, 2)\n",
|
|
|
" x = self.block6(x) # (B, 120, 1, 1)\n",
|
|
|
- " out = self.lm(x.view(x.shape[0], -1))\n",
|
|
|
+ " out = self.lm(x.view(x.shape[0], -1)) # (B, 10)\n",
|
|
|
" return out\n",
|
|
|
"\n",
|
|
|
"model = ResNet()"
|