|
|
@@ -177,8 +177,6 @@
|
|
|
" # 随机生成输入\n",
|
|
|
" inputs = torch.randn(B, T, input_size)\n",
|
|
|
" hs, cs = torch.randn((2 * num_layers, B, hidden_size)).chunk(2, 0)\n",
|
|
|
- " _hs = list((i.squeeze(0) for i in hs))\n",
|
|
|
- " _cs = list((i.squeeze(0) for i in cs))\n",
|
|
|
" re = inputs\n",
|
|
|
" # 取出模型参数\n",
|
|
|
" for layer_index in range(num_layers):\n",
|
|
|
@@ -199,7 +197,7 @@
|
|
|
" model.lstm.out_gate.weight = nn.Parameter(o)\n",
|
|
|
" model.lstm.out_gate.bias = nn.Parameter(ob)\n",
|
|
|
" # 计算隐藏状态\n",
|
|
|
- " re = model(re, (_hs[layer_index], _cs[layer_index]))\n",
|
|
|
+ " re = model(re, (hs[layer_index], cs[layer_index]))\n",
|
|
|
" ref_re, _ = ref_model(inputs, (hs, cs))\n",
|
|
|
" # 验证计算结果(最后一层的隐藏状态是否一致)\n",
|
|
|
" out = torch.all(torch.abs(re - ref_re) < 1e-4)\n",
|