@@ -363,7 +363,7 @@
"s = states[-1]\n",
"r = rewards[-1]\n",
"values = [model(i).item() for i in s]\n",
- "# 验证gamma=1时,gae等同于MC学习\n",
+ "# 验证lambda_=1时,gae等同于MC学习\n",
"mc_advantage = []\n",
"for i in range(len(r)):\n",
" G = compute_cum_rewards(r[i:], gamma)\n",
@@ -390,7 +390,7 @@
}
],
"source": [
- "# 验证gamma=0时,gae等同于TD学习\n",
+ "# 验证lambda_=0时,gae等同于TD学习\n",
"vt_next = values[:-1] + [0.0]\n",
"td_advantage = torch.tensor(r) + gamma * torch.tensor(vt_next) - torch.tensor(values)\n",
"gae = GAE(gamma, 0)\n",