{ "cells": [
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "import torch\n",
  "import torch.nn as nn\n",
  "import torch.nn.functional as F\n",
  "import torch.optim as optim\n",
  "from utils import Lottery, plot_values, plot_action_probs\n",
  "\n",
  "\n",
  "torch.manual_seed(12046)"
 ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor(10.0095),\n", " tensor(9.0089),\n", " tensor(8.0156),\n", " tensor(7.0169),\n", " tensor(6.0171),\n", " tensor(4.9980),\n", " tensor(4.0106),\n", " tensor(3.0101),\n", " tensor(2.0078),\n", " tensor(0.9982)]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "def get_cum_rewards(r, gamma):\n",
  "    '''\n",
  "    Compute the discounted return-to-go for every step of one game.\n",
  "\n",
  "    The value at step t is r[t] + gamma * r[t+1] + gamma**2 * r[t+2] + ...\n",
  "    Parameters\n",
  "    ----\n",
  "    r : per-step rewards in play order; anything that supports reversed()\n",
  "    gamma : discount factor applied to each later step\n",
  "    Returns\n",
  "    ----\n",
  "    cum_rewards : list, one discounted return per step, in the same order as r\n",
  "    '''\n",
  "    cum_rewards = []\n",
  "    last_cum_reward = 0\n",
  "    # Walk backwards so each return reuses the one that follows it.\n",
  "    # Append and reverse once at the end: insert(0, ...) inside the loop\n",
  "    # would shift the whole list on every step (quadratic in game length).\n",
  "    for j in reversed(r):\n",
  "        last_cum_reward = j + gamma * last_cum_reward\n",
  "        cum_rewards.append(last_cum_reward)\n",
  "    cum_rewards.reverse()\n",
  "    return cum_rewards\n",
  "\n",
  "get_cum_rewards(torch.normal(1, 0.01, (10,)), 1)"
 ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
  "# Hyperparameters\n",
  "gamma = 0.9           # discount factor passed to get_cum_rewards\n",
  "learning_rate = 0.01  # optimizer step size\n",
  "grad_clip = 1.0       # NOTE(review): not referenced by any cell in this notebook\n"
 ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
  "class PolicyNet(nn.Module):\n",
  "\n",
  "    def __init__(self, n_states=2, n_actions=2, emb_dim=4):\n",
  "        '''\n",
  "        Policy network: one embedding vector per state, then a linear\n",
  "        layer that produces one logit per action.\n",
  "        Parameters\n",
  "        ----\n",
  "        n_states : int, number of distinct state ids (default 2)\n",
  "        n_actions : int, number of actions to score (default 2)\n",
  "        emb_dim : int, width of the state embedding (default 4)\n",
  "        '''\n",
  "        super().__init__()\n",
  "        self.emb = nn.Embedding(n_states, emb_dim)\n",
  "        self.ln = nn.Linear(emb_dim, n_actions)\n",
  "\n",
  "    def forward(self, x):\n",
  "        '''\n",
  "        Forward pass.\n",
  "        Parameters\n",
  "        ----\n",
  "        x : torch.LongTensor, game states, shape (G), where G is the number of game steps\n",
  "        Returns\n",
  "        ----\n",
  "        out : torch.FloatTensor, logits, shape (G, n_actions)\n",
  "        '''\n",
  "        x = F.relu(self.emb(x))\n",
  "        out = self.ln(x)\n",
  "        return out\n",
  "\n",
  "# Integer ids used to feed game states to the network\n",
  "tokenizer = {'w': 0, 'l': 1}"
 ] },
 { "cell_type": "code", "execution_count": 
5, "metadata": {}, "outputs": [], "source": [
  "@torch.no_grad()\n",
  "def play_game(model, game):\n",
  "    '''\n",
  "    Play one full game, sampling every action from the policy network.\n",
  "    Parameters\n",
  "    ----\n",
  "    model : maps a (1) LongTensor of state ids to (1, 2) logits\n",
  "    game : environment object providing reset() and step(action)\n",
  "    Returns\n",
  "    ----\n",
  "    states : list of the states in which an action was taken, in order\n",
  "    actions : list of sampled action tensors, one per step\n",
  "    rewards : list of the rewards returned by game.step, one per step\n",
  "    '''\n",
  "    states, actions, rewards = [], [], []\n",
  "    s = game.reset()\n",
  "    while True:\n",
  "        state_id = torch.tensor([tokenizer[s]])  # (1)\n",
  "        probs = F.softmax(model(state_id), dim=-1)  # (1, 2)\n",
  "        # Draw the next action from the distribution the network predicts\n",
  "        action = torch.multinomial(probs, 1)\n",
  "        next_s, r = game.step(action)\n",
  "        # Log this step: the action, the state it was taken in, and the reward\n",
  "        actions.append(action)\n",
  "        states.append(s)\n",
  "        rewards.append(r)\n",
  "        s = next_s\n",
  "        # The game is over once the state 't' is reached\n",
  "        if next_s == 't':\n",
  "            break\n",
  "    return states, actions, rewards"
 ] },
 { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(['l'], [tensor([[0]])], [0])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "model = PolicyNet()\n",
  "game = Lottery()\n",
  "play_game(model, game)"
 ] },
 { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAgMAAAHuCAYAAAAC6Q+WAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1rElEQVR4nO3deXwkd33n/9dX50ijc27P5RlfY49tzNjGZBLuOwECG9gkSy6TQDaPJBACCQRvfpDNBljix7LhyuEQWPILISEJyRpsIDiBAD8YMDa+L2yPx3Nf0kgaaXT29/dHdWt6NFKrW+pWdatfz8ejH9VVXdX9KXVX91vf+lZViDEiSZLqV0PaBUiSpHQZBiRJqnOGAUmS6pxhQJKkOmcYkCSpzhkGJEmqc4YBSZLqXFPaBRQSQgjARmAo7VokSapBncChOM9Jhao6DJAEgQNpFyFJUg3bDBwsNEO1h4EhgP3799PV1ZV2LZIk1YzBwUG2bNkCRbSuV3sYAKCrq8swIElShdiBUJKkOmcYkCSpzhkGJEmqc4YBSZLqnGFAkqQ6ZxiQJKnOGQYkSapzhgFJkuqcYUCSpDpnGJAkqc4ZBiRJqnOGAUmS6pxhQJKkOmcYkCSpzpUcBkIIa0IIe0MI24qc//khhIdDCCdCCG8vuUJJklRRJYWBEMIa4IvAtiLnXwvcCnwW2A38XAjhhSXWKEmSKqipxPn/Dvhb4NlFzv9zwCHgf8QYYwjhD4FfAb5W4utKZRVjzA4hZsfj9Hgk+/D0eO5+wecs4XUX9xzzzVDEc8wz07yvUdzLLMn6zrcuxbzQUtVRjr/rfH/TYl9nsc9RjvUtx2eomOcp7u9R+W2imOfY0LWC7vbm+Wcss1LDwJtjjHtDCB8ucv5rgK/Fs+/m94D/OdfMIYRWoDVvUmeJ9dWdTCZydGiUvuFxTo9OMjw+ydDoJKMTU4xNZhifzEwPx6eS4VQmMpnJMDkVmczE7HhkKpNhYioZz8RIJmZ/JCNk8ock0zN50/PHM9Pj586TyU5L5p8xTm5DyR8/+3WTe43sLIV/wKfnOf+5yvElKUmV8sevfwY/ff2WJX/dksJAjHFvic/fBTyUNz4IbCww/7uB95b4GnVhcirDvQcGePzYEHtPjPDUiWH2nhhmX98woxOZtMuTllwIRcwz73MUnmP+5YupYZ6ZFvdwUXXMV0M1/C2LWdFF11AN6zlPHa1N6fTrL7VloFSTwFje+CjQXmD+DwAfyhvvBA5UoK6aMDAywdcfO8a/PXyMrz96jMHRyVnna2oIrFrZQkdrEx0rmljZ0kR7SyMtTQ3JrbHh7P2mBpobGmhsCDQ3BhobGmhqCOeNhwANIdDQkHyRhJBsBA0hGW8InJ1nxnzT0/KWaWjIH89Oy24RIeS/Rt797Pols52dFkLIDs9+yeU2rlydsz2efZpznj83b34dBGZ9rWwVBd+zYr5sFrv8Ymso7stqCX4kF/vHklQ2lQ4DfcDavPFOYHyumWOMY+SFh3r8sjh46gxfuv8wdzx8lDuf6mcqc7Zdu6e9mas3dXPRmpVsy962r17J5t42mho9SlSStDCVDgN3Am/IG98FHKzwa9akvSeGef/tD3PHw0fP2a+9Y30nL7piHS+5Yh3P3NJLY0P9BSRJUmWVJQyEELqAMzHGiRkP3Qp8PITwEuA/gHcCXynHay4XAyMT/Nl/PMFffetJJqaSFPDs7at4+ZUbeOnO9WxZVWiviiRJi1euloH7gLcB/5I/McZ4IoTw28DtwGngFHBjmV6zph0bHOUT39rLZ/bsY3h8CoDnXbaW97zqCi5Z50EUkqSls6AwEGMMM8a3FZj3z0MIXwEuB74ZYzy9kNdcLvaeGOaWbzzJP911gPGp5CiAHes7+Z2X7+AlV6yry34SkqR0VbrPADB9SGKphyUuKw8cHODjX3ucLz94ZLpPwPUX9vLrL7yYF+4wBEiS0rMkYaCejU5M8aGvPsYnvvkkuQMDXnz
5On7tBRfzrG2r0i1OkiQMAxX13SdP8q5/uo+nTo4A8KpnXMBbXnQpOzbYJ0CSVD0MAxXy6W8/xXtvfRBIzjX9R6+9ipfsXJ9yVZIknc8wUGYTUxlu+caTfOirjwHw+us2855X76RrxdJfeEKSpGIYBsrsHZ+7l1vvPQTAzz5rCx/4qavtHChJqmqGgTL62iPHuPXeQzQ2BD74umfwums3GQQkSVXPMFAmZ8aneM+tDwDwK8/Zzuuv25xyRZIkFcer25TJX37zSfb3nWFj9wp+68WXpl2OJElFMwyUQYyRf7hrPwC/+4odrGy1wUWSVDsMA2XwwMFB9vedoa25kVdceUHa5UiSVBLDQBl88f7k6IEXXb6OtpbGlKuRJKk0hoFFOnl6jL/d8zQAr75mY8rVSJJUOsPAIn38a08wNDbJVZu6eJlnGJQk1SDDwCL0DY/zme/uA+CdL7+chgbPKSBJqj2GgUXY8+RJxiYzXLa+g+deuibtciRJWhDDwCLcva8fgBu2r/JMg5KkmmUYWIS7nk7CwHUX9qZciSRJC2cYWKCxySkePDgIwLVbDQOSpNplGFigBw4OMj6VYfXKFrauak+7HEmSFswwsEA/yO4i2LW11/4CkqSaZhhYoB88fQqAay/sSbUOSZIWyzCwQA8cGgDgGZt60i1EkqRFMgwswODoBPtOjgBw5caulKuRJGlxDAML8PCh5CiCjd0r6F3ZknI1kiQtjmFgAR46nISBnRu7U65EkqTFMwwswGNHhwC44oLOlCuRJGnxDAML8Pix0wBcsq4j5UokSVo8w8ACPHF8GICL1xoGJEm1zzBQor7hcfqGxwHDgCRpeTAMlOiH2f4Cm3vbaGtpTLkaSZIWzzBQosey/QUuW2/nQUnS8mAYKFGuZeDS9e4ikCQtD4aBEv3waLZlYJ0tA5Kk5cEwUKKDp84AcOFqL1ssSVoeDAMliDFyZHAUgPVdK1KuRpKk8jAMlKB/ZILxyQxgGJAkLR+GgRIcHkh2EazpaKGlyT+dJGl58BetBEcG3EUgSVp+DAMlyPUX2GAYkCQtI4aBEhw+lYSBC3oMA5Kk5cMwUIJD2cMKL+huS7kSSZLKxzBQgkPZDoSbegwDkqTlwzBQgsPZDoQXdLubQJK0fBgGipTJxOkwsNGWAUnSMmIYKNLJ4XHGJzOE4KGFkqTlxTBQpNwJh9Z2tHrCIUnSsuKvWpEOTR9W6C4CSdLyYhgoUq5lYKOdByVJy4xhoEgH+7NhwJYBSdIyYxgo0oFsGNjSaxiQJC0vhoEiHTg1AsDm3vaUK5EkqbwMA0XKtQxsXmXLgCRpeTEMFGFodIJTIxOApyKWJC0/hoEi5FoFetqb6VzRnHI1kiSVl2GgCNO7COw8KElahgwDRTjQn+082GPnQUnS8mMYKMITx08DsHW1YUCStPwYBorwwMFBAK7e1J1yJZIklZ9hoAi53QTb16xMuRJJksrPMDCPM+NTnDg9DtiBUJK0PBkG5pFrFehsbaK7zcMKJUnLj2FgHrnDCjf1thFCSLkaSZLKzzAwj/3ZloEtqzySQJK0PBkG5uEJhyRJy51hYB65PgNbvFqhJGmZMgzMY3+fLQOSpOXNMDCPA/YZkCQtc4aBAk6PTdKfvXSxLQOSpOXKMFDA/r6kVcBLF0uSljPDQAH3HxwA4NJ1HSlXIklS5RgGCrhn/ykArt3am24hkiRVUElhIIRwVQjhzhBCfwjh5jDPKflC4s9CCH0hhFMhhP8TQqiZne/3HTgFwDM296RahyRJlVR0GAghtAJfAO4Crgd2AjfOs9gvADuAXcBzgSuBdy+k0KU2OjHFI4eHALhmi5culiQtX6W0DPw40A28Pcb4BHAT8CvzLHMD8I8xxn0xxvuBfwEuWUihS+3hw4NMZiKrV7awqadmGjMkSSpZKWHgGmBPjHEkO34fSetAIQ8CPx9CWB9CuBD4WeCrc80cQmgNIXTlbkBnCfWV1X0Hks6Dz9jc7QW
KJEnLWilhoAvYmxuJMUZgKoRQqHfdJ4AO4AjwVHb5TxeY/93AQN7tQAn1ldW99heQJNWJUsLAJDA2Y9ooUOjUfL8FnAIuBLYCTcDNBeb/AMmuiNxtcwn1lVWuZcD+ApKk5a6UMNAHrJ0xrRMYL7DMzwE3xxifjjHuJ/nPf85+BjHGsRjjYO4GDJVQX9mMTU6x98QwAFduNAxIkpa3UsLAncDu3EgIYTvQShISCj3/urzxDUBjKQWmYd/JEaYykc7WJtZ1tqZdjiRJFdVUwrzfALpCCG+MMX6K5GiCO2KMUyGEHmAoxjg1Y5lvAr8XQpgCWoB3AbeWoe6KevzYaQAuWtdh50FJ0rJXdBiIMU6GEN4EfDaEcDOQAV6Qfbif5FwC98xY7PdJOh7+Mckuha+Q9COoak9kw8DFa1emXIkkSZVXSssAMcZbQwgXA9eRHGZ4Mjt91n+fY4yngF9cbJFLbV/2AkXbVxsGJEnLX0lhACDGeAS4rQK1VI0D/UkY2LKq0IESkiQtD16oaBYH+s8AsLnXMw9KkpY/w8AMk1MZDg+MArC515YBSdLyZxiY4cjgKFOZSHNj8LBCSVJdMAzMkNtFsKmnjYYGDyuUJC1/hoEZnjyenHlwk/0FJEl1wjAww+33Hwbg2dtXp1yJJElLwzCQ59jQKN9+4gQAr33mppSrkSRpaRgG8ty9r59MhJ0XdLF1tUcSSJLqg2EgK8bIlx84AsDOjV0pVyNJ0tIxDGT9zXef5l/uOQTAZes7Uq5GkqSlYxjIuu2+Q9P3X3T5+hQrkSRpaRkGgNGJKe7edwqAT//yDVyyzpYBSVL9MAwAPzx6mvGpDGs6WnjepWvSLkeSpCVlGODsVQq3rmonBM86KEmqL4YB8q9S6OGEkqT6YxjgbMuAlyyWJNUjwwCw35YBSVIdq/swEGPk0SNDAGxbYxiQJNWfug8DTxw/zcFTZ2hpbOAZm3vSLkeSpCXXlHYBaRmfzPDNHx7nkWyrwDO39NDRWrd/DklSHavbX79f/8zd3PHw0enxKzd5PQJJUn2qy90Ep8cmzwkCAK+7dnNK1UiSlK66bBl48ODA9P1f3H0hv/q8izySQJJUt+oyDNyfDQMv3bmeP3zNVSlXI0lSuupyN8FjR5NOg1dutJ+AJEl1GQYOnRoFYIu7BiRJqtcwkJxxcGOPpx+WJKnuwkCMkYPZMLDJMCBJUv2Fgf6RCcYmMwCs725NuRpJktJXd2Egt4tgbWcrrU2NKVcjSVL66i4MHLS/gCRJ56i7MHBour/AipQrkSSpOtTdSYe2rV7Ja5+5kWsv7E27FEmSqkKIMaZdw5xCCF3AwMDAAF1dniBIkqRiDQ4O0t3dDdAdYxwsNG/d7SaQJEnnMgxIklTnDAOSJNU5w4AkSXXOMCBJUp0zDEiSVOcMA5Ik1TnDgCRJdc4wIElSnTMMSJJU5wwDkiTVOcOAJEl1zjAgSVKdMwxIklTnDAOSJNU5w4AkSXXOMCBJUp0zDEiSVOcMA5Ik1TnDgCRJdc4wIElSnTMMSJJU5wwDkiTVOcOAJEl1zjAgSVKdMwxIklTnDAOSJNU5w4AkSXXOMCBJUp0zDEiSVOcMA5Ik1TnDgCRJdc4wIElSnTMMSJJU5+o3DMQIk+NpVyFJUurqMwwcvhc+tBM+ei2cOZV2NZIkpar+wsC9fwd/8TwYOgQD++HJr6ddkSRJqSopDIQQrgoh3BlC6A8h3BxCCEUu1xBC+HYI4R0LK7OMxofPHT/+aDp1SJJUJYoOAyGEVuALwF3A9cBO4MYiF/81oBv4SIn1ld+lL4Wf/VvY/ZvJ+AnDgCSpvjWVMO+Pk/ygvz3GOBJCuAn4OPCpQguFEDYC7wd+KsY4Mc+8rUBr3qTOEuorTs/W5BYa4DsfgxOPlf0lJKlaxRiZmJhgcnIy7VJUQGNjIy0tLRTZAL9opYSBa4A9McaR7Ph9JK0D8/kTYB+wJYTwozHGbxeY993Ae0uoaeHWXJY
MTzwOmQw01F/3CUn1I8bI8ePHOX78OKOjo2mXoyI0NjbS09PDmjVr6OjoqOhrlRIGuoC9uZEYYwwhTIUQemOM/bMtEELYDfxn4HbgYuD3QwhfiTH+5hyv8QHgQ3njncCBEmosXs+F0NgCk2eSjoS9F1bkZSSpGuzfv5/jx4/T09PDxo0bl/S/TpUmxsjU1BRDQ0P09/dz8uRJLrroInp7eyv2mqWEgUlgbMa0UaAdmDUMAG8Gvgu8Khse/hLYF0L4aIzxvJ31Mcax/Neo6Ae1sQlWXQzHH052FRgGJC1TfX19HD9+nAsvvJA1a9akXY6K1NXVxcaNG9m7dy9PPvkkO3bsqFgLQSlt433A2hnTOoFCZ+7ZDNweY4wAMcb9wHGSVoL0rc3uKvCIAknLWF9fHx0dHQaBGhRCYPv27TQ3N3PvvfcyPl6Zk+WVEgbuBHbnRkII20k6+/UVWOYA0Ja3TAewCjhYWpkVsmZHMjz2cLp1SFKFZDIZhoaG6O7uTrsULVAIgdWrV9PU1MQdd9zBxETBvvgLUkoY+AbQFUJ4Y3b8JuCOGONUCKEnhNA4yzKfBd4cQnhxCOFC4E+BR0g6H6Zv07XJ8MD30q1DkipkfHycTCZDe3t72qVoETo7O2lqauLxxx/nqaeeKvvzFx0GYoyTwJuAj4UQTgCvAd6VfbgfuHqWZb6anefPSELApcDrc7sNUrf5hmR44jEYKdTAIUm1KZPJAEnPdNWu3PvX0NCQbhgAiDHeSrK//5eAK2KMD2WnhxjjPXMs81cxxstijG0xxt2zdRxMzcrVsPrS5P4fb4fbfifdeiSpQjxyoLbl3r+2tjb27dtHuf+nLvng+hjjkRjjbTHGk2WtJC3bn3v2/p1/CUNH0qtFkqQCGhsbK3LSKM+087x3QlvesZuP35FeLZIkFRBCKHurABgGoOsCuPH2s4HgwJ3p1iNJ0hIzDACs3wkv+YPk/uDhVEuRJGmpGQZyOi9IhkOGAUlSfTEM5BgGJEl1yjCQ07UxGQ4fh6nyn91JkqRqZRjIaV+dXMUQPLxQklRXDAM5IZxtHTj1dLq1SJK0hAwD+dZkr2J46Afp1iFJ0hIyDOTb9pxkuH9PunVIkrSEmtIuoKqsvTwZ9j2VahmStNRijJyZmEq7jIpra26syHUadu/ezZ49e3jf+97HTTfdBMDb3vY2PvzhD9PU1MTQ0BArVqzgzJkzdHR0kMlk2LNnD89+9rPLXstCGAbyrbooGfY9CTEm/QgkqQ6cmZhi53u+knYZFffQH76c9pby//TlwsD9998/Pe0HP0h2OU9OTvLAAw9w/fXX89BDD5HJZFixYgXXXntt2etYKHcT5OvZCgSYGIbTx9KuRpJUI370R38UYDoMxBi59957aWhIfmZzwSD3+PXXX09zc3MKlc7OloF8Ta3QvRkG9sOpfdC5Pu2KJGlJtDU38tAfvjztMiqurbmxIs+bCwOPPvoo4+PjHDhwgIGBAd7ylrfw0Y9+lHvuuQc4GwZ2795dkToWyjAwU++2JAz0PwVbbki7GklaEiGEijSf14uNGzeydetWnn76aR555BEef/xxAF784hdz6623ntcykAsP1cLdBDP1XpgM+/amW4ckqabk7yrI/fjv2rWLXbt2cd9995HJZAwDNaN3ezI88Vi6dUiSakp+GLjnnntYtWoVW7duZdeuXQwPD/Od73yHI0eOcNFFF7Fu3bqUqz2XYWCmzc9Khj/8KkxNpluLJKlm5PoB5FoGdu3aBTA9/Ou//mug+loFwDBwvm3PgRU9MDYAB+9KuxpJUo145jOfSXt7O7fffjsHDx48LwzccsstgGGgNjQ0wiUvSe7f+9l0a5Ek1Yympiae9axnTY/nQsDmzZtZs2bN9PRqO5IADAOzu+7GZHjf5+BMf6qlSJJqR/4PfS4M5N/v7Ozk6quvXvK65mMYmM2258C6K5OTD3361cnZCCVJmkduF0B7ezs7duyYnp4LAzfccAO
NjZU518FiGAZmEwK84gPJ/SP3w+ChdOuRJNWEV7/61cQYGR4enj77IMAHP/hBYozccccdKVY3N8PAXC56PrT1Jvfv+QwMHk7u778T/vdV8O2PpVebJEllZBgo5IqfTIZfex/8zU8l92/77eQMhf/639x9IElaFgwDheSfjvjYQzA+fO7lje/5zJKXJElSuRkGCrn8lbBu59nxvd+A8aGz4w98fulrkiSpzAwDhbT1wq9/BzZdl4zf97lk2JC97OT+73qWQklSzTMMFGPVRcnwwWxLwLW/kJylcPw0fP399h2QJNU0w0Ax1u44d3zzDXDNf0nuf/N/wfc/ufQ1SZJUJoaBYlzy0nPHL9wNz3nb2fGH/u+SliNJUjkZBopxwTXQszW5v/s3oXcbdG6A//rNZNqR+1MrTZKkxWpKu4CaEAL84q0weDA5VXFO18ZkeKYPpiagsTmd+iRJWgRbBoq1avu5QQCgbRWE7Dmmh48vfU2SJJWBYWAxGhpg5drk/ulj6dYiSdICGQYWqyMbBmwZkCTVKMPAYq1clwxtGZAk1SjDwGJ1ZMPAsGFAklSbDAOLZZ8BSVKNMwwsVsf6ZGgYkCTVKMPAYk2HgaPp1iFJqipf//rXCSGwbdu2tEuZl2FgsXJnJux/KtUyJElaKMPAYq3angwHDsDkWLq1SJK0AIaBxVq5FppXAhFOPZ12NZIklcwwsFghJBcuAujbm2opkiQthGGgHNbuSIYH70q3DkmSFsCrFpbDxS+CBz8PD38BXvB7SWuBJNWSGGFiJO0qKq+53e/oWRgGyuHyV8LtvwPHHoSjD8KGq9KuSJJKMzEC79+YdhWVd9MhaFmZdhVVx90E5dC+CtZcltwfPJhuLZIklciWgXJZuSYZjpxMtw5JWojm9uS/5uWuuT3tCqqSYaBc2g0DkmpYCDaf1zF3E5RL++pkaBiQJNUYw0C55MLA8Il065AkqUSGgXJZmWsZ6Eu3DkmSSmQYKBd3E0iSapRhoFymw4C7CSRJtcUwUC4dG5Lh0NF065AkqUSGgXLpXJ8Mx4dgbCjdWiRJKoHnGSiX1k5o6UzCwNDRZFySVLde8IIXEGNMu4yi2DJQTp25XQWH061DkqQSGAbKqeuCZGgYkCTVEMNAOXUaBiRJtccwUE7TuwmOpFuHJEklMAyUU65lYLAOrvwlSVo2DAPl1L0lGZ7al24dkiSVwDBQTr3bkmG/YUCSVDsMA+XUvioZjg5AjRxbKkmSYaCcVvQkwzjlWQglVZVaOfmNZlfp988wUE7NbdDYktwfHUi3FkkCGhsbAZicnEy5Ei3G+Pg4ULn30TBQTiGcbR0YPZVmJZIEQEtLCy0tLQwM+A9KLTt58iTj4+OMjY1V5PlLDgMhhKtCCHeGEPpDCDeHEEIJy/aEEA6HELaV+ro1o60nGZ45lWYVkgRACIHe3l5OnjzJ8PBw2uVoAY4fP87AwAD9/f1MTU3R1NREU1N5Ly1U0rOFEFqBLwBfAX4W+AhwI/CpIp/iZmBDKa9Zc2wZkFRlLrjgAk6fPs2jjz7K6tWr6enpoampiRL+l9MSijEyNTXFyMgIp06dYnh4mP7+/un7l1xySdnfu1KjxY8D3cDbY4wjIYSbgI9TRBgIITwP+EngZMlV1hJbBiRVmcbGRi699FLuvvtuDh48yIkTJ9IuSUWIMTI8PMzg4CBDQ0PTuwi2b99e9tcqNQxcA+yJMY5kx+8Dds63ULZF4S+AtwIfnGe+1rxJtXcd4BXdydAOhJKqSGNjI5dddhlf+MIXOHbsGBs2bKClpcXWgSqVyWSYnJwkxkiMkZGREY4fP87ll19eFWGgC9ibG4kxxhDCVAihN8bYX2C5m4DHYox/H0KYMwwA7wbeW2JN1cXdBJKqVHd3N6985Sv50pe+xOHDhxkbG6OhocFAUMUymQwAbW1t7Ny5k5e97GW0trbOs1TpSg0Dk8DMroyjQDswaxgIIVwB/Bqwq4jn/wDwobzxTuB
AiTWmy5YBSVWst7eXn/mZn+Ho0aMcOHCAoaEhDzusYi0tLXR1dbFlyxbWrl1bseBWahjoA66aMa0TGJ9t5uyRBrcAvx9jnPfqPTHGMfLCRk2mVfsMSKpyjY2NbNy4kY0bN6ZdiqpEqYcW3gnszo2EELaT7OPvm2P+rcBzgJtDCKdCCKey0+4LIbyh9HJrgLsJJEk1ptSWgW8AXSGEN8YYP0XSF+COGONUCKEHGIoxTuXNfxCY2dPhWySHJd6zsJKrnLsJJEk1pqQwEGOcDCG8CfhsCOFmIAO8IPtwP0m/gHvy5weeyn+OEMIkcCDGeHrBVVczdxNIkmpMyacwijHeGkK4GLiO5DDDk9npRe3gjzFuK/U1a4otA5KkGrOg8xnGGI8At5W5luXBPgOSpBrjhYrKLdcyMDkKE6Pp1iJJUhEMA+XW2gVk95i4q0CSVAMMA+XW0JDXb+BUqqVIklQMw0Al5MKARxRIkmqAYaAScocXuptAklQDDAOV4G4CSVINMQxUQvvqZDjsNcMlSdXPMFAJHRuS4ekj6dYhSVIRDAOV0JkNA0OGAUlS9TMMVELnBclw6HC6dUiSVATDQCXYMiBJqiGGgUqYbhk4mm4dkiQVwTBQCbmWgbEBGB1MtxZJkuZhGKiEFV3Qvia53/dkurVIkjQPw0Cl5HYVjJxMtw5JkuZhGKiU1s5kODaUbh2SJM3DMFAphgFJUo0wDFSKYUCSVCMMA5ViGJAk1QjDQKWs6EqGYx5aKEmqboaBSpluGTAMSJKqm2GgUlpzLQPuJpAkVTfDQKXYZ0CSVCMMA5ViGJAk1QjDQKUYBiRJNcIwUCn2GZAk1QjDQKXkwoBXLZQkVTnDQKXkH1oYY7q1SJJUgGGgUnJhgAjjw6mWIklSIYaBSmlug9CY3LffgCSpihkGKiUEjyiQJNUEw0AlrfCIAklS9TMMVFKrFyuSJFU/w0AluZtAklQDDAOV5JULJUk1wDBQSbYMSJJqgGGgknJhwLMQSpKqmGGgklo6kuGEJx2SJFUvw0AltaxMhp6BUJJUxQwDldTcngzHR9KtQ5KkAgwDlZRrGXA3gSSpihkGKmm6ZcAwIEmqXoaBSpruM+BuAklS9TIMVJK7CSRJNcAwUEnuJpAk1QDDQCW5m0CSVAMMA5XkbgJJUg0wDFSSuwkkSTXAMFBJLdkwkJmEyfF0a5EkaQ6GgUpqXnn2vrsKJElVyjBQSU0t0NCc3HdXgSSpShkGKq01e+XCsaF065AkaQ6GgUpb0Z0MRwfTrUOSpDkYBiptRU8yHD2VZhWSJM3JMFBp0y0DA+nWIUnSHAwDlZYLA2dOpVqGJElzMQxUWltPMrRlQJJUpQwDlTa9m+BUqmVIkjQXw0Cl2YFQklTlDAOVZp8BSVKVMwxUWltvMrTPgCSpShkGKi0XBkb60q1DkqQ5GAYqrWNdMjx9NN06JEmag2Gg0jrWJ8ORkzA1mW4tkiTNwjBQae2rITQAEUZOpF2NJEnnMQxUWkMjtK9J7p8+lm4tkiTNwjCwFHK7CgwDkqQqZBhYCh1rk6GdCCVJVcgwsBQ6NiTD00fSrUOSpFkYBpZC1wXJcPBwunVIkjQLw8BS6MyGgSHDgCSp+pQUBkIIV4UQ7gwh9IcQbg4hhCKWeW8IoS+EMBZC+OcQQufCy61RXRuToWFAklSFig4DIYRW4AvAXcD1wE7gxnmW+Tng54BXAFcCVwC/t8Baa1dnts+AuwkkSVWolJaBHwe6gbfHGJ8AbgJ+ZZ5ltgC/FGP8XozxceDvgV0LqrSWdWZbBk4fhcxUurVIkjRDUwnzXgPsiTGOZMfvI2kdmFOM8X/OmLQD+OFc82dbH1rzJi2PXQors4cWxqnkUsYrV6dajiRJ+UppGegC9uZGYowRmAoh9BazcAjhMuA/AbcUmO3dwEDe7UAJ9VWvxiZo7U7un/HqhZKk6lJKGJgExmZMGwXa51swhNAAfBL4RIzxwQKzfoBkV0Tutrm
E+qpbW08yPNOfahmSJM1UShjoA9bOmNYJjBex7P8DrAJ+t9BMMcaxGONg7gYMlVBfdWvLNqAYBiRJVaaUMHAnsDs3EkLYTrJ/v2C7dwjh1cDbgdfl9TeoP+2rkqFhQJJUZUoJA98AukIIb8yO3wTcEWOcCiH0hBAaZy4QQrgC+CzwFmB/CKEjhDDvboVlKdcyMGKfAUlSdSk6DMQYJ4E3AR8LIZwAXgO8K/twP3D1LIv9KrAS+DRJk/8Q8NBiCq5Z7iaQJFWpUg4tJMZ4awjhYuA6ksMMT2anz3omwhjjbwO/vegqlwPDgCSpSpUUBgBijEeA2ypQy/I2HQbcTSBJqi5eqGiptGU7ENpnQJJUZQwDS2X6aALDgCSpuhgGlkp79hTEI/YZkCRVF8PAUpk+tPBkunVIkjSDYWCp5HYTTAzD5MyzOkuSlB7DwFJp7YbceZnsRChJqiKGgaXS0OCuAklSVTIMLKVcJ0KPKJAkVRHDwFJauSYZnj6Wbh2SJOUxDCylzguS4dDhdOuQJCmPYWApdW9KhgMH061DkqQ8hoGl1LU5GQ4eSLcOSZLyGAaWUtfGZDh4KN06JEnKYxhYSu4mkCRVIcPAUuq5MBmePgJjp9OtRZKkLMPAUmpfBSvXJfdPPJZuLZIkZRkGltraHcnw+KPp1iFJUpZhYKlNh4GH061DkqQsw8BS2/CMZHjgrnTrkCQpyzCw1LbuToYHvw+T4+nWIkkShoGlt+bS5IJFk6Nw6AdpVyNJkmFgyYUA256T3N/7H+nWIkkShoF0XPTCZPjE19KtQ5IkDAPpuOj5yfDAnTA+km4tkqS6ZxhIQ+926NoEmQnY/920q5Ek1TnDQBpCgG3PTe4/+fVUS5EkyTCQlktenAwf/7d065Ak1T3DQFoufhEQ4Oj9MHg47WokSXXMMJCWlWtg467kvocYSpJSZBhIU+5shAe+n24dkqS6ZhhI0+brkuFBw4AkKT2GgTRtviEZHr4PRvrSrUWSVLcMA2nq2QLrr4Y4BY/clnY1kqQ6ZRhI25WvSYZ3fgJiTLcWSVJdMgyk7bo3QnM7HL4H7v707PPECEcegOOPLmlpkqT60JR2AXVv5Rp4/jvhjj+AL/42nPgh7HwttLTDqf2w71vw8Behfy80NMEbvwxbnjX388UIQ4fh2ENJeNhwNWx/3lKtjSSpBoVYxU3TIYQuYGBgYICurq60y6mcTAa+8Fb4wf87/7ztq+Hl709OZxynoP8pOPZI9sc/OxwdODt/Uxu8ay80t1WsfElS9RkcHKS7uxugO8Y4WGhew0C1iBEe+0qyq+DQD2ByDDrWw+br4eIXwpYfgb97Q7I7YT6hEVZfDKeehslR+K/fhAueUfFVkCRVj1LCgLsJqkUIsOMVyW0ub/wSfOdjcN/noO9JCA3JEQlrdsC6K2DdzmS45lJoaoVPvgKe/k6yu8AwIEmag2GglrS0J/0Lnv/OpCUhRmgo0Ad07Y5sGHhk6WqUJNUcw0CtCiG5FbL2imR45P7K1yNJqlkeWricbX12Mnx6D2Sm0q1FklS1DAPL2fqroaUTxgZsHZAkzckwsJw1Np09x8DDt6ZbiySpahkGlrurX5cM7/17mBxPtxZJUlUyDCx3O34iOV/B4AH4/l+lXY0kqQoZBpa75jZ4/ruS+199LzzweS+IJEk6h2cgrAcxwj/8Ejz0f5PxdTth87OSFoMQkrMUjpyE4ZPQ2gmv+lAylCTVLM9AqHOFAK/7JKx6H3zn48n1C449NPf86y6H575j6eqTJKXKloF6M9IHT/w7nHwcTh9LgkJjC7SvgtPH4Xt/kVwM6TfuhJWr065WkrRAtgxobu2r4OrXz/7Y1AQ89c2k1eBf/xv8pz9f2tokSamwA6HOamyGV34ouf/gP8PY6XTrkSQtCcOAzrX1R2DVRUmnQk9UJEl1wTCgc4UAu34+uf+t/w1Tk+nWI0mqOMOAzvesN0NbL5x4DO7+P2lXI0mqMMOAzre
iC15wU3L/Wx9OtxZJUsUZBjS7Z74BQgMMPA0DB9OuRpJUQYYBza61A9Zfldw/8L10a5EkVZRhQHPb+iPJ8PF/S7cOSVJFGQY0tyt+Mhk+fCtMjKZbiySpYgwDmtuFPwbdW2B0AL7/ybSrkSRViGFAc2togOf9TnL/3/8HPL0n3XokSRXhhYpUWGYKPvP65OJGAOuuhFXboaUDGpqgeQX86Fugd1uqZUqSzlXKhYoMA5rf+DDc+lZ44B9nf3zD1XDj7cn5CSRJVcEwoMoYOgqH7oaBAzBxBkZPJacsjhlYcxk873dh23Ohc0NyWuOZYoTMZLLsxBmYGEmGk2eAABc8M9k1IUlaNMOAls6he+Az/xmGj52d1tgKLe3JME7B5DhMjcHkGFDg8/aCd8MLfq/SFUtSXTAMaGmd6Yc9f54cgnj80SQAzCtAczs0t0FjCwwdSsLDL38JNl1X8ZIlabkzDCg9k+MwdDi5BPLkWNLJsLEFmlqSH/vG5iQENLWe3ZUQI/ztT8MP/xWaVsANvwqXvQJWXwLtq5JlZhNjsosiNyRvvLlt9l0VklRuMSadrafGs7eJ8+9nJs6dvuqiine8Ngyo9owOwj/cCE/McrbDMKMfQYwU3N0AyVEPr/1T2PjMMhUoqeIyU8kPZu6HMzOZ92M6mffYZN48M8cnZ/wAz/LY9A/zxNw/3nPen+Px+b6TZgqN8GO/Bdf+AvRur8g/L4YB1aZMBh77Mtz7WTh4d7LrIGYW/nyhAS55KWz7MejeDG3ZVobQkGyIpWx8LR2w9vKl7+CY+4+jsWlpX1e1Iff5yEzm3WaOZ2/TP16z/ac6waw/wrOOT577Yzrfj3Ipy5T6g1rNGpqTVtHG3DDv/vgwDB44O29LJ3SshRU9ybldLn9lWUooJQz4DaPq0dAAl/9EcoMkHIyeyn5J5MRsS0HI/qhnf9Bz90MDjJ2Gr7wbHvxn+OFXkls5rOhJmvZWrj2726OhMfnyjVPZL+G8+/nTMhN5X8r5X9ITZ7+8p/Lu5+aPmWRdr3odvOpDsKK7POtSa3I/ejP/tjEz99/8vHmnks/UrNPn+0GdmuV+KctMnn0/S1pm5viM5Yvqn1PjGpqTH9GG5iQUT483FZ7e2DJjnpnLzPIjXfL9Ao8X+mcjxqSP1Xdvgf3fhfEh6BtKHhst+JtdMSW3DIQQrgI+BVwCfAJ4Z5znSUIIrwf+F9AMvCPG+NkiX8uWAS3c8UfhkdvgyH1w+ljS0TH35R8zZ39oizF8HMZPV7TcebV2w/bnwppLs4GkFZrasuGoQP+J3Lrmfljm+yEt6Qd2tnln/uBOFvgRnm36LLUup/8Yl0Ku9auhKfsD2HT2h6oh735u+swf1EI/oPP9CE//RzzXY4WeI3+8JQnby73vz+QY9O+DkRPJqd83XJ20ZJZBxXYThBBagUeArwA3Ax8B/jHG+KkCy1wF3AX8BvBd4PPAq2KMjxbxeoYBVYepCTj2UHKOhZGTSdPq5HjyYxUas19aDdlh47nD6S/k5rP3z5mWmyf3JZk/TzMc/gF8+d1w4rG0/wpVLJz/N5/1/Wg4//0JjTP+5k3nvm8zx2e+r7POP89znDfeWOQyRcyTW0/VvUqGgdcCnwQ2xxhHQgjXAB+PMT6nwDJ/AlweY3xFdvy3gLUxxt+fZd5WoDVvUidwwDCgupeZggPfhwN3wql9cOZUcrKmidHkP+ncLpL83Se5jpchzPiRLPCjOOf0UuadOb1pAa+X/59t/mNzBK7l/t+jtACV7DNwDbAnxjiSHb8P2FnEMl/KG/8e8J455n038N4Sa5KWv4ZG2Prs5CZJZVZqW1IXsDc3ku0rMBVC6C12GWAQ2DjHvB8AuvNu5dlxIkmS5lRqy8AkMDZj2ijQDvQXuUxu/vPEGMfy5w02/UmSVHGltgz0AWtnTOsExktYZr75JUnSEio1DNwJ7M6NhBC2k3T46yt2GWAXcLDE15UkSRVSahj4BtAVQnhjdvw
m4I4Y41QIoSeE0DjLMv8E/GwI4eoQQgfwVpJDEyVJUhUoKQzEGCeBNwEfCyGcAF4DvCv7cD9w9SzL3At8GPg+SYvAFPCni6hZkiSV0YKuTRBC2ABcR3KY4ckil9kJbAL+I8ZYVJ8BTzokSdLCVPzaBDHGI8BtJS7zEPDQQl5PkiRVjueslCSpzhkGJEmqc4YBSZLqnGFAkqQ6ZxiQJKnOGQYkSapzhgFJkuqcYUCSpDpnGJAkqc4t6AyES21wsOBZFCVJ0gyl/HYu6NoESyWEsAk4kHYdkiTVsM0xxoOFZqj2MBCAjcBQmZ+6kyRkbK7Ac6dhua0PuE61wnWqDa5TbajEOnUCh+I8P/ZVvZsgW3zBNLMQScYAYGi+KznVguW2PuA61QrXqTa4TrWhQutU1PPYgVCSpDpnGJAkqc7VaxgYA/57drgcLLf1AdepVrhOtcF1qg2prVNVdyCUJEmVV68tA5IkKcswIElSnTMMSJJU5wwDkiTVuboLAyGEq0IId4YQ+kMIN4e8szxUqxDCa0IIT4YQJkMI94QQrshO/0gIIebdHs9bpqrXc67aC9UdQnh+COHhEMKJEMLb06v+fCGEG2esT+52Ywjh1hnT7shbrurWKYSwJoSwN4SwLW/agt6XEMLrQwj7QgiHQgj/ZQlX4xxzrNOs21X2sarftuZYpwXVXS2fw5nrVGi7yj5e1dtWge/u6tueYox1cwNagb3AnwMXA7cBb0y7rnlqvhjoA34aWA98Dvj/so99G/gJoCd766yV9Zyt9kJ1A2uBAeA9wKXAXcAL016PvPVpyVuXHpLTiR7Prsch4Kq8x1ZW6zoBa4A9QAS2zfd5KrQO2XUeA94EXA38ENhRJes053Y11+dzvr9F2uu00Lqr5XM4x/s053aVfbxqt625PmPVuj0t6Zud9g14bfbNac+OXwN8K+265qn5VcCv5o2/EBghOZX0ANBRa+s5V+2F6gbeBjzM2cNhXwP8TdrrUmAdbwJuATYBh+eYp+rWCbgDeOuML+QFvS/AnwBfznvu3wL+qErWadbtqtDnc76/RRWs04LqrpbP4WzrNMs8NwG3ZO9X9bY112esWrenettNcA2wJ8Y4kh2/D9iZYj3zijF+McZ4S96kHSSJ8GqS3Tz3hBDOhBC+HELYmp2n2tdzrtoL1X0N8LWY3QqA7wHXLWXRxQohrCDZUN8P3AA0hhAOhBCGQwh/F0Lozc5ajev05hjjR2ZMW+j7cg3w73nPk9b6nbdOBbYrqI1ta7b3aaF1V8vncLZ1mjZju4Iq37YKfMaqcnuqtzDQRdI8A0xfCGkq7wNU1UIILcA7SJqXdgKPAr8APAOYJPlPFKp/PeeqvVDd5zxGcvGNjUtVcIneAHw3xvgUcDlwL/BK4EeA7cAHsvNV3TrFGPfOMnmh70tVrN8c6zRtxnYFNbBtzbFOC627Jt4nzt2uoIa2rRmfsarcnqr6qoUVMMn5p3kcBdqB/qUvp2T/HRgGPhFjnAA+k3sghPDrwN4QQhdVvp4xxs8wS+0kzWNz1T1znXLTq9GvAX8AEGP8AGe/oAgh/C7w+ew8tbJOhT5PhdahVtZveruCuT+f1b5tLaLuWnmfprcrqLltK/8z9kdU4fZUby0DfSQdNPJ1AuMp1FKSEMKLgN8A3pANAjMdI3k/L6D21jNX+xHmrnvmOlXl+oQQLgEuAb46xyzHgNUhhFZqZJ0o/HkqtA5Vv35FbFdQu9tWsXXXwvs033YFVbptzfIZq8rtqd7CwJ3A7txICGE7Sc/OvtQqKkK2zs8CvxFjfCg77eYQwhvyZtsNZID9VPl6Fqj9fuau+5x1AnYBBytfbcl+Gvhi7oclhPD3IYTn5D2+GzgaYxyjdtap0Oep0DpU9frNtl1lp9fktrWIuqv6fco6Z7uC2ti25viMVef2tJS9K9O+kewWOcbZwzj+EvhC2nXNU3Mb8CDJvr+OvNsvAE8CLwZeRrKv8FO1sJ7Az89We6G
6SQ47OgO8BGgGvgR8NO11mWXdvgH8ct7475NsxM8h6UV8BHhvta8T5/dSL/l9IenwdJqkY1sH8APgHVWyTnNtV2Guz+d8f4sqWKcF1V1tn0NmOZpg5naVnVbV21aBz1hzNW5PqbzZad6AnyTZd3Mi+4bsTLumeep9TXbjmHnbRrK/7BRwEvgw2WNsa2E956q9UN0k+wJzTWlPAuvTXo8Z69RGsk/v8rxpzcBfZTfiwyTHDzdV+zrN/EJe6PsCvC/7NxkAvg+0VcM6FdquCn0+5/tbVMH7tKC6q+lzOMs6nbddZadX9bZV6DNWjdtTXV7COISwgeSQjD0xxpNp11MptbqeherONqldDnwzxng6jfrKrVbWaaHvSwhhJ8kx4f8RY6yqfdEL5bZVG6p5napte6rLMCBJks6qtw6EkiRpBsOAJEl1zjAgSVKdMwxIklTnDAOSJNU5w4AkSXXOMCBJUp0zDEiSVOcMA5Ik1bn/H1pmP44fXr+HAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
  "# REINFORCE (Monte Carlo policy gradient)\n",
  "num_episodes = 2000  # number of games to play, one parameter update per game\n",
  "model = PolicyNet()\n",
  "optimizer = optim.AdamW(model.parameters(), lr=learning_rate)\n",
  "# One entry per game: {state: [probability of each action]} after the update\n",
  "v = []\n",
  "\n",
  "for _ in range(num_episodes):\n",
  "    states, actions, rewards = play_game(model, game)\n",
  "    # Treat one game of G steps as G separate training samples\n",
  "    cum_rewards = get_cum_rewards(rewards, gamma)\n",
  "    cum_rewards = torch.tensor(cum_rewards)  # (G)\n",
  "    actions = torch.concat(actions).squeeze(-1)  # (G)\n",
  "    states = torch.tensor([tokenizer[s] for s in states])  # (G)\n",
  "    optimizer.zero_grad()\n",
  "    logits = model(states)  # (G, 2)\n",
  "    # cross_entropy is the negative log-probability of the chosen action,\n",
  "    # so negating it gives the log-probability itself\n",
  "    lnP = -F.cross_entropy(logits, actions, reduction='none')  # (G)\n",
  "    # Model loss: log-probability of each action weighted by the return that followed it\n",
  "    loss = -cum_rewards * lnP  # (G)\n",
  "    loss.mean().backward()\n",
  "    # NOTE(review): the grad_clip hyperparameter is defined earlier but is\n",
  "    # never applied; confirm whether gradient clipping was intended here.\n",
  "    optimizer.step()\n",
  "    # Record the probability the model assigns to each action in each state.\n",
  "    # This is inference only, so no autograd graph is needed.\n",
  "    eval_re = {}\n",
  "    with torch.no_grad():\n",
  "        for k in tokenizer:\n",
  "            _re = F.softmax(model(torch.tensor([tokenizer[k]])), dim=-1)  # (1, 2)\n",
  "            eval_re[k] = _re.squeeze(0).tolist()\n",
  "    v.append(eval_re)\n",
  "\n",
  "fig = plot_action_probs(v)\n",
  "fig.savefig('policy_learning.png', dpi=200)"
 ] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }