torch.manual_seed(12046)


def get_cum_rewards(r, gamma):
    """Compute the discounted cumulative reward (return) for every step.

    The rewards are walked from the last step to the first, applying
    ``G_t = r_t + gamma * G_{t+1}`` with the return after the final step
    taken as 0.

    Parameters
    ----------
    r : reversible sequence
        Per-step rewards in chronological order (numbers or 0-d tensors).
    gamma : float
        Discount factor applied to the return of the following step.

    Returns
    -------
    list
        Same length as ``r``; element ``t`` is the discounted return
        collected from step ``t`` onwards. Empty if ``r`` is empty.
    """
    cum_rewards = []
    last_cum_reward = 0
    for j in reversed(r):
        last_cum_reward = j + gamma * last_cum_reward
        # Append and reverse once at the end: inserting at index 0 on every
        # iteration shifts the whole list each time and is O(n^2) overall.
        cum_rewards.append(last_cum_reward)
    cum_rewards.reverse()
    return cum_rewards


# Demo: with gamma = 1 and rewards close to 1 the returns count down ~10, 9, ..., 1.
# This call also draws from the global RNG, so it is kept to preserve the
# random stream seen by the cells that follow.
get_cum_rewards(torch.normal(1, 0.01, (10,)), 1)

# Hyperparameters
gamma = 0.9            # discount factor passed to get_cum_rewards
learning_rate = 0.01   # optimizer step size
grad_clip = 1.0


class PolicyNet(nn.Module):
    """Policy network: maps a game-state index to logits over the two actions."""

    def __init__(self):
        """Build the layers: a state embedding followed by a linear action head."""
        super().__init__()
        self.emb = nn.Embedding(2, 4)   # one 4-d vector per state index (0 or 1)
        self.ln = nn.Linear(4, 2)       # 4-d hidden features -> 2 action logits

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.LongTensor
            Game states, shape (G), where G is the number of game steps.

        Returns
        -------
        out : torch.FloatTensor
            Unnormalised action scores (logits), shape (G, 2).
        """
        x = F.relu(self.emb(x))
        out = self.ln(x)
        return out


# Numeric encoding of the game states fed to the network
tokenizer = {'w': 0, 'l': 1}
@torch.no_grad()
def play_game(model, game):
    """Play one full game, sampling every action from the policy network.

    Runs without gradient tracking; the trajectory is only collected here.

    Parameters
    ----------
    model : callable
        Maps a LongTensor of state indices, shape (1), to logits, shape (1, 2).
    game : object
        Environment exposing ``reset() -> state`` and
        ``step(action) -> (next_state, reward)``; the state ``'t'`` ends the game.

    Returns
    -------
    tuple of three lists, each with one entry per step played:
        the states visited, the sampled actions (tensors of shape (1, 1)),
        and the rewards received.
    """
    visited_states, chosen_actions, collected_rewards = [], [], []
    state = game.reset()
    while True:
        state_idx = torch.tensor([tokenizer[state]])          # (1)
        action_probs = F.softmax(model(state_idx), dim=-1)    # (1, 2)
        # Sample the next action from the distribution given by the network
        action = torch.multinomial(action_probs, 1)           # (1, 1)
        following_state, reward = game.step(action)
        # Record the trajectory: state, action and reward of this step
        visited_states.append(state)
        chosen_actions.append(action)
        collected_rewards.append(reward)
        if following_state == 't':
            break
        state = following_state
    return visited_states, chosen_actions, collected_rewards


# Smoke run: play a single game with a freshly initialised policy
model = PolicyNet()
game = Lottery()
play_game(model, game)
# REINFORCE: Monte-Carlo policy-gradient training of the policy network
model = PolicyNet()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
num_episodes = 2000  # games played; one optimizer step is taken per game
v = []  # one entry per episode: {state: [P(action 0), P(action 1)]}

for t in range(num_episodes):
    states, actions, rewards = play_game(model, game)
    # Treat one G-step game as G separate (state, action, return) samples
    cum_rewards = get_cum_rewards(rewards, gamma)
    cum_rewards = torch.tensor(cum_rewards)                     # (G)
    actions = torch.concat(actions).squeeze(-1)                 # (G)
    states = torch.tensor([tokenizer[s] for s in states])       # (G)
    optimizer.zero_grad()
    logits = model(states)                                      # (G, 2)
    # cross_entropy yields -log p(action); negate it to get the log-probability
    lnP = -F.cross_entropy(logits, actions, reduction='none')   # (G)
    # Model loss: log-probability of each taken action weighted by its return
    loss = -cum_rewards * lnP                                   # (G)
    loss.mean().backward()
    # NOTE(review): the `grad_clip` hyperparameter is defined in this notebook
    # but is not applied in this loop -- confirm whether clipping was intended.
    optimizer.step()
    # Record, for every state, the probability the model assigns to each action.
    # This is pure logging, so skip building an autograd graph for it.
    with torch.no_grad():
        eval_re = {}
        for k in tokenizer:
            _re = F.softmax(model(torch.tensor([tokenizer[k]])), dim=-1)  # (1, 2)
            eval_re[k] = _re.squeeze(0).tolist()
    v.append(eval_re)

fig = plot_action_probs(v)
fig.savefig('policy_learning.png', dpi=200)