Browse Source

fix test lstm bug

Gen TANG cách đây 1 năm
mục cha
commit
7741fb7cd0
2 tập tin đã thay đổi với 2 bổ sung và 6 xóa
  1. + 1 - 3
      ch10_rnn/lstm.ipynb
  2. + 1 - 3
      video/lstm.ipynb

+ 1 - 3
ch10_rnn/lstm.ipynb

@@ -177,8 +177,6 @@
     "    # 随机生成输入\n",
     "    inputs = torch.randn(B, T, input_size)\n",
     "    hs, cs = torch.randn((2 * num_layers, B, hidden_size)).chunk(2, 0)\n",
-    "    _hs = list((i.squeeze(0) for i in hs))\n",
-    "    _cs = list((i.squeeze(0) for i in cs))\n",
     "    re = inputs\n",
     "    # 取出模型参数\n",
     "    for layer_index in range(num_layers):\n",
@@ -199,7 +197,7 @@
     "        model.lstm.out_gate.weight = nn.Parameter(o)\n",
     "        model.lstm.out_gate.bias = nn.Parameter(ob)\n",
     "        # 计算隐藏状态\n",
-    "        re = model(re, (_hs[layer_index], _cs[layer_index]))\n",
+    "        re = model(re, (hs[layer_index], cs[layer_index]))\n",
     "    ref_re, _ = ref_model(inputs, (hs, cs))\n",
     "    # 验证计算结果(最后一层的隐藏状态是否一致)\n",
     "    out = torch.all(torch.abs(re - ref_re) < 1e-4)\n",

+ 1 - 3
video/lstm.ipynb

@@ -790,8 +790,6 @@
     "    # 随机生成输入\n",
     "    inputs = torch.randn(B, T, input_size)\n",
     "    hs, cs = torch.randn((2 * num_layers, B, hidden_size)).chunk(2, 0)\n",
-    "    _hs = list((i.squeeze(0) for i in hs))\n",
-    "    _cs = list((i.squeeze(0) for i in cs))\n",
     "    re = inputs\n",
     "    # 取出模型参数\n",
     "    for layer_index in range(num_layers):\n",
@@ -812,7 +810,7 @@
     "        model.cell.out_gate.weight = nn.Parameter(o)\n",
     "        model.cell.out_gate.bias = nn.Parameter(ob)\n",
     "        # 计算隐藏状态\n",
-    "        re = model(re, (_hs[layer_index], _cs[layer_index]))\n",
+    "        re = model(re, (hs[layer_index], cs[layer_index]))\n",
     "    ref_re, _ = ref_model(inputs, (hs, cs))\n",
     "    # 验证计算结果(最后一层的隐藏状态是否一致)\n",
     "    out = torch.all(torch.abs(re - ref_re) < 1e-4)\n",