Gen TANG 2 роки тому
батько
коміт
5885b61320

+ 0 - 2994
ch12_rl/Untitled.ipynb

@@ -1,2994 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 115,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n",
-      "\u001b[0mCollecting trl\n",
-      "  Using cached trl-0.7.4-py3-none-any.whl.metadata (10 kB)\n",
-      "Requirement already satisfied: torch>=1.4.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (2.0.1)\n",
-      "Requirement already satisfied: transformers>=4.18.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (4.31.0)\n",
-      "Requirement already satisfied: numpy>=1.18.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (1.24.4)\n",
-      "Requirement already satisfied: accelerate in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (0.21.0)\n",
-      "Requirement already satisfied: datasets in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (2.14.2)\n",
-      "Requirement already satisfied: tyro>=0.5.11 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (0.5.12)\n",
-      "Requirement already satisfied: filelock in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (3.0.12)\n",
-      "Requirement already satisfied: typing-extensions in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (4.8.0)\n",
-      "Requirement already satisfied: sympy in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (1.6.2)\n",
-      "Requirement already satisfied: networkx in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (2.5)\n",
-      "Requirement already satisfied: jinja2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (2.11.2)\n",
-      "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.16.4)\n",
-      "Requirement already satisfied: packaging>=20.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (23.1)\n",
-      "Requirement already satisfied: pyyaml>=5.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (5.3.1)\n",
-      "Requirement already satisfied: regex!=2019.12.17 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (2020.10.15)\n",
-      "Requirement already satisfied: requests in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (2.24.0)\n",
-      "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.13.3)\n",
-      "Requirement already satisfied: safetensors>=0.3.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.3.1)\n",
-      "Requirement already satisfied: tqdm>=4.27 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (4.65.0)\n",
-      "Requirement already satisfied: docstring-parser>=0.14.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (0.15)\n",
-      "Requirement already satisfied: rich>=11.1.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (13.6.0)\n",
-      "Requirement already satisfied: shtab>=1.5.6 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (1.6.4)\n",
-      "Requirement already satisfied: psutil in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from accelerate->trl) (5.7.2)\n",
-      "Requirement already satisfied: pyarrow>=8.0.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (12.0.1)\n",
-      "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (0.3.7)\n",
-      "Requirement already satisfied: pandas in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (2.0.3)\n",
-      "Requirement already satisfied: xxhash in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (3.3.0)\n",
-      "Requirement already satisfied: multiprocess in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (0.70.15)\n",
-      "Requirement already satisfied: fsspec>=2021.11.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from fsspec[http]>=2021.11.1->datasets->trl) (2023.6.0)\n",
-      "Requirement already satisfied: aiohttp in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (3.8.5)\n",
-      "Requirement already satisfied: attrs>=17.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (23.1.0)\n",
-      "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (3.2.0)\n",
-      "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (6.0.4)\n",
-      "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (4.0.2)\n",
-      "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.9.2)\n",
-      "Requirement already satisfied: frozenlist>=1.1.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.4.0)\n",
-      "Requirement already satisfied: aiosignal>=1.1.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.3.1)\n",
-      "Requirement already satisfied: chardet<4,>=3.0.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (3.0.4)\n",
-      "Requirement already satisfied: idna<3,>=2.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (2.10)\n",
-      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (1.25.11)\n",
-      "Requirement already satisfied: certifi>=2017.4.17 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (2020.6.20)\n",
-      "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n",
-      "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n",
-      "Requirement already satisfied: MarkupSafe>=0.23 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from jinja2->torch>=1.4.0->trl) (1.1.1)\n",
-      "Requirement already satisfied: decorator>=4.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from networkx->torch>=1.4.0->trl) (4.4.2)\n",
-      "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2.8.2)\n",
-      "Requirement already satisfied: pytz>=2020.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2020.1)\n",
-      "Requirement already satisfied: tzdata>=2022.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2023.3)\n",
-      "Requirement already satisfied: mpmath>=0.19 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from sympy->torch>=1.4.0->trl) (1.1.0)\n",
-      "Requirement already satisfied: mdurl~=0.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n",
-      "Requirement already satisfied: six>=1.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from python-dateutil>=2.8.2->pandas->datasets->trl) (1.15.0)\n",
-      "Using cached trl-0.7.4-py3-none-any.whl (133 kB)\n",
-      "\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n",
-      "\u001b[0m\u001b[33mDEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
-      "\u001b[0mInstalling collected packages: trl\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Successfully installed trl-0.7.4\r\n"
-     ]
-    }
-   ],
-   "source": [
-    "!pip install --user trl"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "import torch.nn.functional as F"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "a = torch.randn(1, 4, requires_grad=True)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[ 0.2511, -0.2847, -1.3987,  0.9078]], requires_grad=True)"
-      ]
-     },
-     "execution_count": 8,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "a"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 9,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "s = torch.argmax(a)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(3)"
-      ]
-     },
-     "execution_count": 10,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "s"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 19,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([ 0., 10.,  3.,  0.], requires_grad=True)"
-      ]
-     },
-     "execution_count": 19,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "weights = torch.tensor([0, 10, 3, 0], dtype=torch.float, requires_grad=True) # create a tensor of weights\n",
-    "weights"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 20,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "s = torch.multinomial(weights, 1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 21,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([2])"
-      ]
-     },
-     "execution_count": 21,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "s"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 22,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "RuntimeError",
-     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-22-12626f95b38d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
-      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
-     ]
-    }
-   ],
-   "source": [
-    "s.backward()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 226,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "from transformers import AutoTokenizer, GPT2LMHeadModel\n",
-    "\n",
-    "\n",
-    "torch.manual_seed(12046)\n",
-    "\n",
-    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 227,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "50257"
-      ]
-     },
-     "execution_count": 227,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "tokenizer.vocab_size"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 231,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "1.027730370438185e+47"
-      ]
-     },
-     "execution_count": 231,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "import math\n",
-    "\n",
-    "math.pow(50256, 10)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 55,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
-      "Input length of input_ids is 7, but `max_length` is set to 1. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.\n"
-     ]
-    }
-   ],
-   "source": [
-    "res = model.generate(**ids, max_length=1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 56,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[2061,  318,  262, 3139,  286, 2807,   30,  198]])"
-      ]
-     },
-     "execution_count": 56,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "res"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 53,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "l = model(**ids).logits"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 54,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[[ -37.2172,  -36.8864,  -40.3563,  ...,  -43.4351,  -43.0232,\n",
-       "           -37.0993],\n",
-       "         [ -84.7418,  -82.6453,  -88.2857,  ...,  -89.4501,  -90.0555,\n",
-       "           -84.6770],\n",
-       "         [ -94.2372,  -92.2747,  -95.0347,  ...,  -94.9362,  -97.9820,\n",
-       "           -91.8550],\n",
-       "         ...,\n",
-       "         [ -75.3586,  -73.5543,  -77.7737,  ...,  -80.5165,  -81.9881,\n",
-       "           -76.4165],\n",
-       "         [ -77.2812,  -75.8471,  -80.8869,  ...,  -88.2672,  -86.0258,\n",
-       "           -79.5727],\n",
-       "         [-123.9086, -123.2837, -124.5704,  ..., -135.1090, -134.9653,\n",
-       "          -119.7716]]], grad_fn=<UnsafeViewBackward0>)"
-      ]
-     },
-     "execution_count": 54,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "l"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 60,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[[ -37.2172,  -36.8864,  -40.3563,  ...,  -43.4351,  -43.0232,\n",
-       "           -37.0993],\n",
-       "         [ -84.7418,  -82.6453,  -88.2857,  ...,  -89.4501,  -90.0555,\n",
-       "           -84.6770],\n",
-       "         [ -94.2372,  -92.2747,  -95.0347,  ...,  -94.9362,  -97.9820,\n",
-       "           -91.8550],\n",
-       "         ...,\n",
-       "         [ -75.3586,  -73.5543,  -77.7737,  ...,  -80.5165,  -81.9881,\n",
-       "           -76.4165],\n",
-       "         [ -77.2812,  -75.8471,  -80.8869,  ...,  -88.2672,  -86.0258,\n",
-       "           -79.5727],\n",
-       "         [-123.9086, -123.2837, -124.5704,  ..., -135.1090, -134.9653,\n",
-       "          -119.7716]]])"
-      ]
-     },
-     "execution_count": 60,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "with torch.no_grad():\n",
-    "    ll = model(**ids).logits\n",
-    "ll"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 87,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "weights = torch.tensor([0, 10, 3, 0], dtype=torch.float, requires_grad=True)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 88,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([1])"
-      ]
-     },
-     "execution_count": 88,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "y = torch.multinomial(weights, 1)\n",
-    "y"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 89,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "RuntimeError",
-     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-89-ab75bb780f4c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
-      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
-     ]
-    }
-   ],
-   "source": [
-    "y.backward()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 91,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "<ipython-input-91-26e670d9452d>:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
-      "  F.softmax(weights)[y]\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "tensor([0.9990], grad_fn=<IndexBackward0>)"
-      ]
-     },
-     "execution_count": 91,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "F.softmax(weights)[y]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 82,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([2])"
-      ]
-     },
-     "execution_count": 82,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "y"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 86,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "k = (weights == weights[y]).nonzero()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 51,
-   "metadata": {
-    "scrolled": true
-   },
-   "outputs": [
-    {
-     "ename": "RuntimeError",
-     "evalue": "only Tensors of floating point and complex dtype can require gradients",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-51-e107420eacad>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;31mRuntimeError\u001b[0m: only Tensors of floating point and complex dtype can require gradients"
-     ]
-    }
-   ],
-   "source": [
-    "ids['input_ids'].requires_grad=True"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 49,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "torch.Tensor"
-      ]
-     },
-     "execution_count": 49,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 50,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "torch.Tensor"
-      ]
-     },
-     "execution_count": 50,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "type(a)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 57,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(1.0545, grad_fn=<SumBackward0>)"
-      ]
-     },
-     "execution_count": 57,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "a = torch.rand(1, 4)\n",
-    "a.requires_grad=True\n",
-    "y = torch.sum(a)\n",
-    "y"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 46,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "torch.sum(a).backward()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 37,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(8903)"
-      ]
-     },
-     "execution_count": 37,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "y = torch.sum(ids['input_ids'])\n",
-    "y"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 38,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "RuntimeError",
-     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-38-ab75bb780f4c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
-      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
-     ]
-    }
-   ],
-   "source": [
-    "y.backward()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 94,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x = torch.randn(3, 2, requires_grad=True)\n",
-    "y = torch.ones(3, 2)\n",
-    "x\n",
-    "z = torch.where(x > 0, 1.0, 0.0)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 95,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[0., 1.],\n",
-       "        [1., 1.],\n",
-       "        [0., 0.]])"
-      ]
-     },
-     "execution_count": 95,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "z"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 103,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "t = (x - 0 > 0.001) * 1 + 0"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 100,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[-1.4747,  0.2449],\n",
-       "        [ 0.8045,  0.4111],\n",
-       "        [-0.7769, -0.4431]], requires_grad=True)"
-      ]
-     },
-     "execution_count": 100,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "x"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 104,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[0, 1],\n",
-       "        [1, 1],\n",
-       "        [0, 0]])"
-      ]
-     },
-     "execution_count": 104,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "t"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 106,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[False,  True],\n",
-       "        [ True,  True],\n",
-       "        [False, False]])"
-      ]
-     },
-     "execution_count": 106,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "(x - 0 > 0.001)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 108,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[-1.4747,  0.2449],\n",
-       "        [ 0.8045,  0.4111],\n",
-       "        [-0.7769, -0.4431]], requires_grad=True)"
-      ]
-     },
-     "execution_count": 108,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "x"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 109,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(0.8045, grad_fn=<MaxBackward1>)"
-      ]
-     },
-     "execution_count": 109,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "torch.max(x)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 117,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "ModuleNotFoundError",
-     "evalue": "No module named 'trl'",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-117-8b0f0cf4af74>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtrl\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAutoModelForCausalLMWithValueHea\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'trl'"
-     ]
-    }
-   ],
-   "source": [
-    "from trl import AutoModelForCausalLMWithValueHea"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Users/tgbaggio/.local/lib/python3.8/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n",
-      "  warnings.warn(\n"
-     ]
-    }
-   ],
-   "source": [
-    "import trl"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from trl import AutoModelForCausalLMWithValueHead"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "c8eb15f8baa64dbfa6a41e4349bc1068",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Downloading (…)lve/main/config.json:   0%|          | 0.00/577 [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "c92409e5bc944693b5578571245093bf",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "model = AutoModelForCausalLMWithValueHead.from_pretrained('lvwerra/gpt2-imdb')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "AutoModelForCausalLMWithValueHead(\n",
-       "  (pretrained_model): GPT2LMHeadModel(\n",
-       "    (transformer): GPT2Model(\n",
-       "      (wte): Embedding(50257, 768)\n",
-       "      (wpe): Embedding(1024, 768)\n",
-       "      (drop): Dropout(p=0.1, inplace=False)\n",
-       "      (h): ModuleList(\n",
-       "        (0-11): 12 x GPT2Block(\n",
-       "          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
-       "          (attn): GPT2Attention(\n",
-       "            (c_attn): Conv1D()\n",
-       "            (c_proj): Conv1D()\n",
-       "            (attn_dropout): Dropout(p=0.1, inplace=False)\n",
-       "            (resid_dropout): Dropout(p=0.1, inplace=False)\n",
-       "          )\n",
-       "          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
-       "          (mlp): GPT2MLP(\n",
-       "            (c_fc): Conv1D()\n",
-       "            (c_proj): Conv1D()\n",
-       "            (act): NewGELUActivation()\n",
-       "            (dropout): Dropout(p=0.1, inplace=False)\n",
-       "          )\n",
-       "        )\n",
-       "      )\n",
-       "      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
-       "    )\n",
-       "    (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
-       "  )\n",
-       "  (v_head): ValueHead(\n",
-       "    (dropout): Dropout(p=0.1, inplace=False)\n",
-       "    (summary): Linear(in_features=768, out_features=1, bias=True)\n",
-       "    (flatten): Flatten(start_dim=1, end_dim=-1)\n",
-       "  )\n",
-       ")"
-      ]
-     },
-     "execution_count": 5,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "model"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "l = model(**ids)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 13,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "torch.Size([1, 7, 50257])"
-      ]
-     },
-     "execution_count": 13,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "l[0].shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 15,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "torch.Size([1, 7])"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "l[2].shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from transformers import AutoTokenizer, pipeline"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 17,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "a14f54b9e20d4fa2bf97d5d707531fcf",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Downloading (…)lve/main/config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "1397323a065743a19dc8e235b0a49018",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Downloading pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "898a7c9e19d549459bd8a92006ae3f2b",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Downloading (…)okenizer_config.json:   0%|          | 0.00/333 [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "123feae78d4f4bc182d7405985f33462",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "fbf7ebeb54244310bfe2d225e1441db4",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "7170262df42543be877e56798b06a64a",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\n",
-      "pip install xformers.\n"
-     ]
-    }
-   ],
-   "source": [
-    "p = pipeline('sentiment-analysis', 'lvwerra/distilbert-imdb')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 21,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "sent_kwargs = {\"return_all_scores\": True, \"function_to_apply\": \"none\", \"batch_size\": 16}"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 29,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "{'input_ids': tensor([[2061,  318,  262, 3139,  286, 2807,   30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}"
-      ]
-     },
-     "execution_count": 29,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "question = 'What is the capital of China?'\n",
-    "ids = tokenizer(question, return_tensors=\"pt\")\n",
-    "ids"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 30,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "pipe_outputs = p(question, **sent_kwargs)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 31,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "[[{'label': 'NEGATIVE', 'score': 0.28462526202201843},\n",
-       "  {'label': 'POSITIVE', 'score': -0.498414009809494}]]"
-      ]
-     },
-     "execution_count": 31,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "pipe_outputs"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 32,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "score = torch.randn(3)\n",
-    "log = torch.randn(3, 7)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 45,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "tensor(2.3532) torch.Size([7])\n",
-      "tensor(-0.3089) torch.Size([7])\n",
-      "tensor(-1.0061) torch.Size([7])\n"
-     ]
-    }
-   ],
-   "source": [
-    "for s, l in zip(score, log):\n",
-    "    print(s, l.shape)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 46,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x = torch.randn(4)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 47,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([ 0.8624,  0.1586, -1.3384,  0.4741])"
-      ]
-     },
-     "execution_count": 47,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "x"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 48,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x += 1"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 49,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([ 1.8624,  1.1586, -0.3384,  1.4741])"
-      ]
-     },
-     "execution_count": 49,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "x"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 51,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x[2] += 1"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 55,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[-0.0498,  0.3148, -0.6248,  0.0660],\n",
-       "        [ 1.1840, -1.0514, -0.8765, -0.0313],\n",
-       "        [ 2.3417, -0.0704, -0.0996,  1.7966]])"
-      ]
-     },
-     "execution_count": 55,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "x = torch.randn(3, 4)\n",
-    "x"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 56,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "ValueError",
-     "evalue": "step must be greater than zero",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-56-fc433e12835a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;31mValueError\u001b[0m: step must be greater than zero"
-     ]
-    }
-   ],
-   "source": [
-    "x[::-1]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 79,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "self_config_gamma = 0.8\n",
-    "self_config_lam = 0.9\n",
-    "\n",
-    "def compute_advantages(\n",
-    "    values: torch.FloatTensor,\n",
-    "    rewards: torch.FloatTensor):\n",
-    "    lastgaelam = 0\n",
-    "    advantages_reversed = []\n",
-    "    gen_len = rewards.shape[-1]\n",
-    "    for t in reversed(range(gen_len)):\n",
-    "        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0\n",
-    "        delta = rewards[:, t] + self_config_gamma * nextvalues - values[:, t]\n",
-    "        lastgaelam = delta + self_config_gamma * self_config_lam * lastgaelam\n",
-    "        advantages_reversed.append(lastgaelam)\n",
-    "    advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)\n",
-    "\n",
-    "    returns = advantages + values\n",
-    "    advantages = advantages.detach()\n",
-    "    return values, advantages, returns\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 80,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "values = torch.tensor([[3., 4., 2.]])\n",
-    "rewards = torch.tensor([[-1., 2., 2.]])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 81,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(tensor([[3., 4., 2.]]),\n",
-       " tensor([[-1.0880, -0.4000,  0.0000]]),\n",
-       " tensor([[1.9120, 3.6000, 2.0000]]))"
-      ]
-     },
-     "execution_count": 81,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "compute_advantages(values, rewards)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "1.8800000000000003"
-      ]
-     },
-     "execution_count": 1,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "0.8 * 2 + 0.8 ** 2 * 2 - 1"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import gym\n",
-    "import pygame\n",
-    "\n",
-    "env = gym.make('MountainCar-v0', render_mode=\"human\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "The observation space: Box([-1.2  -0.07], [0.6  0.07], (2,), float32)\n",
-      "The action space: Discrete(3)\n"
-     ]
-    }
-   ],
-   "source": [
-    "# Observation and action space \n",
-    "obs_space = env.observation_space\n",
-    "action_space = env.action_space\n",
-    "print(\"The observation space: {}\".format(obs_space))\n",
-    "print(\"The action space: {}\".format(action_space))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "The initial observation is (array([-0.46305716,  0.        ], dtype=float32), {})\n",
-      "0\n",
-      "The new observation is [-0.46450874 -0.00145157]\n"
-     ]
-    }
-   ],
-   "source": [
-    "import matplotlib.pyplot as plt \n",
-    "\n",
-    "# reset the environment and see the initial observation\n",
-    "obs = env.reset()\n",
-    "print(\"The initial observation is {}\".format(obs))\n",
-    "\n",
-    "# Sample a random action from the entire action space\n",
-    "random_action = env.action_space.sample()\n",
-    "print(random_action)\n",
-    "\n",
-    "# # Take the action and get the new observation space\n",
-    "new_obs, reward, done, _, info = env.step(random_action)\n",
-    "env.render() \n",
-    "pygame.display.update()\n",
-    "print(\"The new observation is {}\".format(new_obs))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "2\n",
-      "The new observation is [-5.2235210e-01 -4.3133463e-04]\n"
-     ]
-    }
-   ],
-   "source": [
-    "# Sample a random action from the entire action space\n",
-    "random_action = env.action_space.sample()\n",
-    "print(random_action)\n",
-    "\n",
-    "# # Take the action and get the new observation space\n",
-    "new_obs, reward, done, _, info = env.step(random_action)\n",
-    "print(\"The new observation is {}\".format(new_obs))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 17,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "env.close()\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 19,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "<function pygame.event.get>"
-      ]
-     },
-     "execution_count": 19,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "import pygame\n",
-    "\n",
-    "pygame.event.get"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "DependencyNotInstalled",
-     "evalue": "Box2D is not installed, run `pip install gymnasium[box2d]`",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/bipedal_walker.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m     \u001b[0;32mimport\u001b[0m \u001b[0mBox2D\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     16\u001b[0m     from Box2D.b2 import (\n",
-      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'Box2D'",
-      "\nThe above exception was the direct cause of the following exception:\n",
-      "\u001b[0;31mDependencyNotInstalled\u001b[0m                    Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-1-c5175ad63d7c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mgymnasium\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"LunarLander-v2\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrender_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"human\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0mobservation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id, max_episode_steps, autoreset, apply_api_compatibility, disable_env_checker, **kwargs)\u001b[0m\n\u001b[1;32m    754\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    755\u001b[0m         \u001b[0;31m# Assume it's a string\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 756\u001b[0;31m         \u001b[0menv_creator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_env_creator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mentry_point\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    757\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    758\u001b[0m     \u001b[0;31m# Determine if to use the rendering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/registration.py\u001b[0m in \u001b[0;36mload_env_creator\u001b[0;34m(name)\u001b[0m\n\u001b[1;32m    543\u001b[0m     \"\"\"\n\u001b[1;32m    544\u001b[0m     \u001b[0mmod_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\":\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 545\u001b[0;31m     \u001b[0mmod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    546\u001b[0m     \u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    547\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/__init__.py\u001b[0m in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m    125\u001b[0m                 \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    126\u001b[0m             \u001b[0mlevel\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_bootstrap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_gcd_import\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlevel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpackage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_load_unlocked\u001b[0;34m(spec)\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap_external.py\u001b[0m in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbipedal_walker\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBipedalWalker\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mBipedalWalkerHardcore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcar_racing\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCarRacing\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlunar_lander\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLunarLander\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLunarLanderContinuous\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/bipedal_walker.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     23\u001b[0m     )\n\u001b[1;32m     24\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m     raise DependencyNotInstalled(\n\u001b[0m\u001b[1;32m     26\u001b[0m         \u001b[0;34m\"Box2D is not installed, run `pip install gymnasium[box2d]`\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m     ) from e\n",
-      "\u001b[0;31mDependencyNotInstalled\u001b[0m: Box2D is not installed, run `pip install gymnasium[box2d]`"
-     ]
-    }
-   ],
-   "source": [
-    "import gymnasium as gym\n",
-    "\n",
-    "env = gym.make(\"LunarLander-v2\", render_mode=\"human\")\n",
-    "observation, info = env.reset()\n",
-    "\n",
-    "for _ in range(1000):\n",
-    "    action = env.action_space.sample()  # agent policy that uses the observation and info\n",
-    "    observation, reward, terminated, truncated, info = env.step(action)\n",
-    "\n",
-    "    if terminated or truncated:\n",
-    "        observation, info = env.reset()\n",
-    "\n",
-    "env.close()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "<contextlib.ExitStack at 0x7fbffd7b04c0>"
-      ]
-     },
-     "execution_count": 2,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "import gymnasium as gym\n",
-    "import math\n",
-    "import random\n",
-    "import matplotlib\n",
-    "import matplotlib.pyplot as plt\n",
-    "env = gym.make(\"CartPole-v1\")\n",
-    "\n",
-    "# set up matplotlib\n",
-    "is_ipython = 'inline' in matplotlib.get_backend()\n",
-    "if is_ipython:\n",
-    "    from IPython import display\n",
-    "\n",
-    "plt.ion()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 52,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import gymnasium as gym\n",
-    "import pygame\n",
-    "\n",
-    "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n",
-    "observation, info = env.reset(seed=42)\n",
-    "#for _ in range(100):\n",
-    "#    action = env.action_space.sample()  # this is where you would insert your policy\n",
-    "#    observation, reward, terminated, truncated, info = env.step(action)\n",
-    "#    if terminated or truncated:\n",
-    "#        break\n",
-    "#env.close()\n",
-    "\n",
-    "#pygame.display.update()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 62,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(array([ 0.20159529,  1.9464185 , -0.22034578, -2.9908078 ], dtype=float32),\n",
-       " 1.0,\n",
-       " True,\n",
-       " False,\n",
-       " {})"
-      ]
-     },
-     "execution_count": 62,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "action = env.action_space.sample()  # this is where you would insert your policy\n",
-    "env.step(1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 63,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "env.close()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 64,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "0.075"
-      ]
-     },
-     "execution_count": 64,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "0.3/4"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 163,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "m = [\n",
-    "    [0.775, 0.075, 0.075, 0.075],\n",
-    "    [0.075, 0.775, 0.075, 0.075],\n",
-    "    [0.075, 0.075, 0.775, 0.075],\n",
-    "    [0.075, 0.075, 0.075, 0.775]\n",
-    "]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 164,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "m = torch.tensor(m)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 95,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[1.],\n",
-       "        [1.],\n",
-       "        [1.],\n",
-       "        [1.]])"
-      ]
-     },
-     "execution_count": 95,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "m @ torch.ones(4, 1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 192,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "r = torch.tensor([[2.0, 1.0, -1.0, -2.0]]).T\n",
-    "v = torch.tensor([[200.0, 100.0, -100.0, -200.0]]).T\n",
-    "gamma = 0.9"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 198,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "tensor([[ 5.4054],\n",
-      "        [ 2.7027],\n",
-      "        [-2.7027],\n",
-      "        [-5.4054]])\n",
-      "tensor([[ 5.4054],\n",
-      "        [ 2.7027],\n",
-      "        [-2.7027],\n",
-      "        [-5.4054]])\n",
-      "tensor([[ 5.4054],\n",
-      "        [ 2.7027],\n",
-      "        [-2.7027],\n",
-      "        [-5.4054]])\n",
-      "tensor([[ 5.4054],\n",
-      "        [ 2.7027],\n",
-      "        [-2.7027],\n",
-      "        [-5.4054]])\n",
-      "tensor([[ 5.4054],\n",
-      "        [ 2.7027],\n",
-      "        [-2.7027],\n",
-      "        [-5.4054]])\n",
-      "tensor([[ 5.4054],\n",
-      "        [ 2.7027],\n",
-      "        [-2.7027],\n",
-      "        [-5.4054]])\n",
-      "tensor([[ 5.4054],\n",
-      "        [ 2.7027],\n",
-      "        [-2.7027],\n",
-      "        [-5.4054]])\n",
-      "tensor([[ 5.4054],\n",
-      "        [ 2.7027],\n",
-      "        [-2.7027],\n",
-      "        [-5.4054]])\n",
-      "tensor([[ 5.4054],\n",
-      "        [ 2.7027],\n",
-      "        [-2.7027],\n",
-      "        [-5.4054]])\n",
-      "tensor([[ 5.4054],\n",
-      "        [ 2.7027],\n",
-      "        [-2.7027],\n",
-      "        [-5.4054]])\n"
-     ]
-    }
-   ],
-   "source": [
-    "for i in range(10):\n",
-    "    print(v)\n",
-    "    v = gamma * m @ v + r"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 142,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "m = [\n",
-    "    [0.5, 0.5, 0.0],\n",
-    "    [0.25, 0.5, 0.25],\n",
-    "    [0, 0.5, 0.5]\n",
-    "]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 143,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "m = torch.tensor(m)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 160,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "n = torch.tensor([[0.0, 0.9, 0.1]])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "NameError",
-     "evalue": "name 'n' is not defined",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-2-17b8a0a14bcd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m     \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;31mNameError\u001b[0m: name 'n' is not defined"
-     ]
-    }
-   ],
-   "source": [
-    "for i in range(10):\n",
-    "    print(n)\n",
-    "    n = n @ m\n",
-    "    \n",
-    "    \n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 51,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch\n",
-    "from torch.autograd import Variable"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 161,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 162,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 170,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(tensor([[1., 0., 0.]], grad_fn=<AddBackward0>),\n",
-       " tensor([[-0.2796,  0.1884,  0.0913]]))"
-      ]
-     },
-     "execution_count": 170,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "logits = torch.tensor([[1., 1., 1.]], requires_grad=True)\n",
-    "y = F.gumbel_softmax(logits, tau=1, hard=True)\n",
-    "(y * torch.arange(3)).sum().backward()\n",
-    "y, logits.grad"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 148,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[1., 1.]], grad_fn=<AddBackward0>)"
-      ]
-     },
-     "execution_count": 148,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "y_soft = F.gumbel_softmax(logits, tau=1, hard=False)\n",
-    "y_soft.max(-1, keepdim=True)[1] - y_soft.detach() + y_soft"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 117,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[1., 0.]], grad_fn=<AddBackward0>)"
-      ]
-     },
-     "execution_count": 117,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "F.gumbel_softmax(logits, tau=1, hard=True)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 81,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def sample_gumbel(shape, eps=1e-20):\n",
-    "    U = torch.rand(shape)\n",
-    "    return -Variable(torch.log(-torch.log(U + eps) + eps))\n",
-    "\n",
-    "def gumbel_softmax_sample(logits, temperature):\n",
-    "    y = logits + sample_gumbel(logits.size())\n",
-    "    return F.softmax(y / temperature, dim=-1)\n",
-    "\n",
-    "def gumbel_softmax(logits, temperature):\n",
-    "    \"\"\"\n",
-    "    ST-gumple-softmax\n",
-    "    input: [*, n_class]\n",
-    "    return: flatten --> [*, n_class] an one-hot vector\n",
-    "    \"\"\"\n",
-    "    y = gumbel_softmax_sample(logits, temperature)\n",
-    "    shape = y.size()\n",
-    "    _, ind = y.max(dim=-1)\n",
-    "    y_hard = torch.zeros_like(y).view(-1, shape[-1])\n",
-    "    y_hard.scatter_(1, ind.view(-1, 1), 1)\n",
-    "    y_hard = y_hard.view(*shape)\n",
-    "    y_hard = (y_hard - y).detach() + y\n",
-    "    return y_hard.view(-1,latent_dim*categorical_dim), y, ind"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 53,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "NameError",
-     "evalue": "name 'latent_dim' is not defined",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-53-56c036ce50bf>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgumbel_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;32m<ipython-input-52-69711d14ce77>\u001b[0m in \u001b[0;36mgumbel_softmax\u001b[0;34m(logits, temperature)\u001b[0m\n\u001b[1;32m     20\u001b[0m     \u001b[0my_hard\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_hard\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     21\u001b[0m     \u001b[0my_hard\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0my_hard\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0my_hard\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlatent_dim\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mcategorical_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;31mNameError\u001b[0m: name 'latent_dim' is not defined"
-     ]
-    }
-   ],
-   "source": [
-    "gumbel_softmax(torch.randn(4, 3), 0.1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 55,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "latent_dim = 10\n",
-    "categorical_dim = 3\n",
-    "x = torch.randn(5, latent_dim, categorical_dim)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 85,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x.requires_grad=True"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 86,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "y, k, ind = gumbel_softmax(x, 0.1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 87,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[1, 2, 2, 2, 2, 0, 2, 1, 2, 1],\n",
-       "        [0, 2, 0, 0, 2, 1, 0, 0, 0, 0],\n",
-       "        [0, 0, 2, 1, 1, 1, 1, 0, 0, 0],\n",
-       "        [1, 2, 1, 2, 0, 1, 1, 1, 0, 0],\n",
-       "        [2, 0, 0, 1, 0, 2, 0, 1, 0, 1]])"
-      ]
-     },
-     "execution_count": 87,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "ind"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 89,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "torch.return_types.max(\n",
-       "values=tensor([[0.5745, 1.0000, 0.9985, 1.0000, 1.0000, 0.5934, 1.0000, 0.8247, 1.0000,\n",
-       "         0.9999],\n",
-       "        [0.7558, 1.0000, 0.9999, 1.0000, 1.0000, 0.5294, 0.9994, 0.8727, 0.9999,\n",
-       "         0.9213],\n",
-       "        [1.0000, 1.0000, 1.0000, 1.0000, 0.9942, 1.0000, 0.9996, 1.0000, 1.0000,\n",
-       "         1.0000],\n",
-       "        [0.9993, 1.0000, 1.0000, 1.0000, 0.9998, 0.9998, 1.0000, 0.9661, 0.5564,\n",
-       "         1.0000],\n",
-       "        [1.0000, 1.0000, 0.9999, 1.0000, 0.9785, 0.8706, 1.0000, 1.0000, 1.0000,\n",
-       "         0.5304]], grad_fn=<MaxBackward0>),\n",
-       "indices=tensor([[1, 2, 2, 2, 2, 0, 2, 1, 2, 1],\n",
-       "        [0, 2, 0, 0, 2, 1, 0, 0, 0, 0],\n",
-       "        [0, 0, 2, 1, 1, 1, 1, 0, 0, 0],\n",
-       "        [1, 2, 1, 2, 0, 1, 1, 1, 0, 0],\n",
-       "        [2, 0, 0, 1, 0, 2, 0, 1, 0, 1]]))"
-      ]
-     },
-     "execution_count": 89,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "k.max(dim=-1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 94,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([1, 2, 2, 2, 2, 0, 2, 1, 2, 1])"
-      ]
-     },
-     "execution_count": 94,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "k.max(dim=-1)[1][0]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 92,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "torch.Size([10])"
-      ]
-     },
-     "execution_count": 92,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "torch.arange(10).shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 171,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[1., 0., 0.]], grad_fn=<AddBackward0>)"
-      ]
-     },
-     "execution_count": 171,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "y"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 107,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "NameError",
-     "evalue": "name 'nn' is not defined",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-107-e31dcaa6664b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;31mNameError\u001b[0m: name 'nn' is not defined"
-     ]
-    }
-   ],
-   "source": [
-    "l = nn.embedding(3, 4)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 172,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "NameError",
-     "evalue": "name 'nn' is not defined",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-172-23c43f10cdd2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mclass\u001b[0m \u001b[0mRewardModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m         \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
-      "\u001b[0;31mNameError\u001b[0m: name 'nn' is not defined"
-     ]
-    }
-   ],
-   "source": [
-    "class RewardModel(nn.Module):\n",
-    "\n",
-    "    def __init__(self, model):\n",
-    "        super().__init__()\n",
-    "        self.embedding = model\n",
-    "        self.score = nn.Linear(model.embed_dim, 1, bias=False)\n",
-    "\n",
-    "    def forward(self, x, seq_len=None):\n",
-    "        # x:表示文本,形状(B, T, C), seq_len:表示文本长度,形状(B)\n",
-    "        B, T = x.shape\n",
-    "        emb = self.embedding(x).last_hidden_state  # (B, T, C)\n",
-    "        ind = torch.arange(B, device=x.device)\n",
-    "        if seq_len == None:\n",
-    "            seq_len = torch.tensor([T] * B)\n",
-    "        # 获取最后一个词元的特征\n",
-    "        pooled_emb = emb[ind, seq_len - 1]         # (B,    C)\n",
-    "        score = self.score(pooled_emb)             # (B,    1)\n",
-    "        return score\n",
-    "\n",
-    "\n",
-    "r_model = RewardModel(GPT2Model.from_pretrained('gpt2'))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 175,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "x = torch.randn(1, 4, requires_grad=True)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 176,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(1.7199, grad_fn=<MaxBackward1>)"
-      ]
-     },
-     "execution_count": 176,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "torch.max(x)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 181,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[1]])"
-      ]
-     },
-     "execution_count": 181,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "logits = torch.randn(1, 5, requires_grad=True)\n",
-    "probs = F.softmax(logits, dim=-1)\n",
-    "y = torch.multinomial(probs, num_samples=1)\n",
-    "y"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 182,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[1., 0., 0., 0., 0.]], grad_fn=<AddBackward0>)"
-      ]
-     },
-     "execution_count": 182,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "y_one_hot = F.gumbel_softmax(logits, hard=True)\n",
-    "gumbel_y = torch.argmax(y_one_hot, dim=-1, keepdim=True)\n",
-    "y_one_hot"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 183,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[0]])"
-      ]
-     },
-     "execution_count": 183,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "gumbel_y"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 189,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def log(t, eps = 1e-20):\n",
-    "    return torch.log(t.clamp(min = eps))\n",
-    "\n",
-    "def gumbel_noise(t):\n",
-    "    noise = torch.zeros_like(t).uniform_(0, 1)\n",
-    "    return -log(-log(noise))\n",
-    "\n",
-    "def gumbel_sample(t, temperature = 1., dim = -1):\n",
-    "    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 190,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([2])"
-      ]
-     },
-     "execution_count": 190,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "gumbel_sample(torch.randn(1, 4, requires_grad=True))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 191,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "B, T, vs = (3, 4, 20)\n",
-    "logits = torch.randn(B, T, vs)\n",
-    "labels = torch.randint(vs, (B, T))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 197,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(tensor(3.9027), tensor(3.9027))"
-      ]
-     },
-     "execution_count": 197,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "l = F.cross_entropy(logits.view(B * T, vs), labels.view(B * T), reduction='none')\n",
-    "l.mean(), F.cross_entropy(logits.view(B * T, vs), labels.view(B * T))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 203,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[-4.3894, -4.6075, -2.8143, -3.5261],\n",
-       "        [-4.3982, -3.5565, -2.3154, -4.5414],\n",
-       "        [-4.1661, -3.0418, -2.9297, -6.5461]])"
-      ]
-     },
-     "execution_count": 203,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "lnP = -F.cross_entropy(logits.transpose(-2, -1), labels, reduction='none')\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 221,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "lossi = F.cross_entropy(logits.transpose(-2, -1), labels, reduction='none')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 225,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(3.9027)"
-      ]
-     },
-     "execution_count": 225,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "lossi.mean(-1).mean(-1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "tokenizer"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 220,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(3.9027)"
-      ]
-     },
-     "execution_count": 220,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "(-r * lnP).mean()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 211,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([-15.3374, -14.8116, -16.6836])"
-      ]
-     },
-     "execution_count": 211,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "lnP.sum(-1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 213,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[0, 0, 0, 0, 1, 1, 1, 0, 1, 1],\n",
-       "        [0, 0, 1, 0, 1, 1, 0, 0, 1, 0],\n",
-       "        [0, 0, 1, 1, 1, 1, 1, 1, 1, 0],\n",
-       "        [0, 1, 1, 1, 1, 1, 0, 0, 0, 1],\n",
-       "        [1, 0, 1, 1, 1, 1, 1, 1, 0, 0],\n",
-       "        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n",
-       "        [1, 0, 0, 0, 0, 1, 1, 1, 0, 0],\n",
-       "        [0, 1, 1, 0, 0, 0, 0, 1, 1, 0],\n",
-       "        [0, 1, 1, 1, 1, 1, 0, 0, 1, 0],\n",
-       "        [1, 1, 1, 1, 0, 0, 0, 0, 1, 1],\n",
-       "        [0, 1, 0, 1, 1, 1, 1, 1, 0, 1],\n",
-       "        [0, 1, 0, 0, 1, 1, 0, 1, 0, 1],\n",
-       "        [0, 0, 1, 1, 0, 1, 0, 0, 0, 0],\n",
-       "        [0, 0, 1, 1, 1, 0, 0, 0, 1, 0],\n",
-       "        [0, 1, 1, 0, 0, 1, 0, 1, 1, 0],\n",
-       "        [0, 0, 0, 0, 0, 1, 1, 1, 1, 0],\n",
-       "        [1, 0, 0, 1, 0, 0, 1, 0, 1, 1],\n",
-       "        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1],\n",
-       "        [1, 1, 1, 0, 1, 1, 1, 1, 0, 0],\n",
-       "        [0, 0, 0, 1, 1, 1, 1, 1, 0, 1]])"
-      ]
-     },
-     "execution_count": 213,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "torch.randint(2, (20, 10))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 23,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "a = torch.tensor([0.5, 0.2, 0.1])\n",
-    "r = torch.tensor([1001., 1002., 1003.])"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 28,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(tensor([500.5000, 200.4000, 100.3000]), tensor(208.2627))"
-      ]
-     },
-     "execution_count": 28,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "a * r, (a * r).std()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 29,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(tensor([0.0000, 0.2000, 0.2000]), tensor(0.1155))"
-      ]
-     },
-     "execution_count": 29,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "a * (r-1001), (a * (r-1001)).std()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 22,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([0., 1., 2.])"
-      ]
-     },
-     "execution_count": 22,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "r-1001"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 34,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(208.1666)"
-      ]
-     },
-     "execution_count": 34,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "(1000 * a).std()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 35,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(0.2082)"
-      ]
-     },
-     "execution_count": 35,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "(1 * a).std()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 38,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([1001.0000,  400.4000,  200.2000])"
-      ]
-     },
-     "execution_count": 38,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "1000 * a + 1002 * a"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 39,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([1.5000, 0.6000, 0.3000])"
-      ]
-     },
-     "execution_count": 39,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "a + 2 * a"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 64,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "torch.Size([2, 20])"
-      ]
-     },
-     "execution_count": 64,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "g = torch.normal(1, 1, (2, 20))\n",
-    "g.shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 78,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "torch.Size([2, 1])"
-      ]
-     },
-     "execution_count": 78,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "r = torch.normal(1, 1, (2, 1))\n",
-    "r.shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 79,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[ 0.9881,  0.7762,  1.3812,  0.7958,  1.0094,  0.1070,  0.1709,  0.3961,\n",
-       "         -0.0118,  0.2604,  0.5242,  0.2400,  0.1989,  0.2737,  1.4879,  0.4091,\n",
-       "          0.7199, -0.2781,  0.7418,  1.2079],\n",
-       "        [ 3.4830,  0.1242,  0.8680,  0.0464,  0.8523,  1.0290, -0.3891,  2.5647,\n",
-       "         -0.1333,  3.3386,  0.1980,  0.2184,  0.4268,  4.9361,  3.0060,  1.3053,\n",
-       "          2.5407,  1.2318,  1.8263,  1.0207]])"
-      ]
-     },
-     "execution_count": 79,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "help(g * r)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 86,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Help on method norm in module torch._tensor:\n",
-      "\n",
-      "norm(p: Union[float, str, NoneType] = 'fro', dim=None, keepdim=False, dtype=None) method of torch.Tensor instance\n",
-      "    See :func:`torch.norm`\n",
-      "\n"
-     ]
-    }
-   ],
-   "source": [
-    "help(g.norm)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 59,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(1.4053)"
-      ]
-     },
-     "execution_count": 59,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "(g * rr).std()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 88,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([39.9750, 37.7648])"
-      ]
-     },
-     "execution_count": 88,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "(g**2).sum(dim=-1)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 179,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def norm_(g):\n",
-    "    k = g / g.norm(dim=-1, keepdim=True)\n",
-    "    return k"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import torch"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(99.9496)"
-      ]
-     },
-     "execution_count": 2,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "num = 10000\n",
-    "grad = torch.normal(0, 1, (num, 100))\n",
-    "g = torch.normal(100, 1, (num, 1))\n",
-    "(grad * g).std()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(1.0004)"
-      ]
-     },
-     "execution_count": 3,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "num = 10000\n",
-    "grad = torch.normal(0, 1, (num, 100))\n",
-    "g = torch.normal(100, 1, (num, 1)) - 100\n",
-    "(grad * g).std()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 261,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(0.0996)"
-      ]
-     },
-     "execution_count": 261,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "(g * r).mean(-1).std()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 241,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[100.],\n",
-       "        [100.],\n",
-       "        [100.],\n",
-       "        [100.],\n",
-       "        [100.],\n",
-       "        [100.],\n",
-       "        [100.],\n",
-       "        [100.],\n",
-       "        [100.],\n",
-       "        [100.]])"
-      ]
-     },
-     "execution_count": 241,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "torch.normal(100, 0, (10, 1))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.5"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}

+ 0 - 400
ch12_rl/intuition_model-Copy1.ipynb

@@ -1,400 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "<torch._C.Generator at 0x7f8c38c68110>"
-      ]
-     },
-     "execution_count": 1,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "import torch\n",
-    "import torch.nn as nn\n",
-    "import torch.nn.functional as F\n",
-    "from transformers import AutoTokenizer, GPT2LMHeadModel, GPT2Model\n",
-    "\n",
-    "\n",
-    "torch.manual_seed(12046)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "llm = GPT2LMHeadModel.from_pretrained('gpt2')\n",
-    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class RewardModel(nn.Module):\n",
-    "\n",
-    "    def __init__(self, model):\n",
-    "        super().__init__()\n",
-    "        self.embedding = model\n",
-    "        self.score = nn.Linear(model.embed_dim, 1, bias=False)\n",
-    "\n",
-    "    def forward(self, x, seq_len=None):\n",
-    "        # x:表示文本,形状(B, T, vs)或者(B, T), seq_len:表示文本长度,形状(B)\n",
-    "        B = x.shape[0]\n",
-    "        T = x.shape[1]\n",
-    "        emb = self.get_last_hidden_state(x)     # (B, T, C)\n",
-    "        ind = torch.arange(B, device=x.device)\n",
-    "        if seq_len == None:\n",
-    "            seq_len = torch.tensor([T] * B)\n",
-    "        # 获取最后一个词元的特征\n",
-    "        pooled_emb = emb[ind, seq_len - 1]      # (B,    C)\n",
-    "        score = self.score(pooled_emb)          # (B,    1)\n",
-    "        return score\n",
-    "    \n",
-    "    def get_last_hidden_state(self, x):\n",
-    "        if len(x.shape) == 2:\n",
-    "            # x shape = (B, T)\n",
-    "            emb = self.embedding(x).last_hidden_state  # (B, T, C)\n",
-    "        # 为后面使用gumbel_softmax做准备,直接与embedding的模型参数进行计算\n",
-    "        else:\n",
-    "            # x shape = (B, T, vs)\n",
-    "            w = self.embedding.get_input_embeddings().weight  # (vs, C)\n",
-    "            inputs_embeds = x @ w  # (B, T, C)\n",
-    "            emb = self.embedding(inputs_embeds=inputs_embeds).last_hidden_state\n",
-    "        return emb\n",
-    "\n",
-    "r_model = RewardModel(GPT2Model.from_pretrained('gpt2'))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor(0., grad_fn=<MaxBackward1>)"
-      ]
-     },
-     "execution_count": 4,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "# 验证评分模型计算正确\n",
-    "x = torch.randint(0, tokenizer.vocab_size, (3, 4))\n",
-    "x_hot = F.one_hot(x, num_classes=tokenizer.vocab_size).float()\n",
-    "(r_model(x) - r_model(x_hot)).abs().max()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class RLModel(nn.Module):\n",
-    "    \n",
-    "    def __init__(self, llm, r_model):\n",
-    "        super().__init__()\n",
-    "        self.llm = llm\n",
-    "        self.r_model = r_model\n",
-    "        # 冻结模型\n",
-    "        for param in r_model.parameters():\n",
-    "            param.requires_grad = False\n",
-    "    \n",
-    "    def generate(self, idx, max_new_tokens):\n",
-    "        model = self.llm\n",
-    "        for _ in range(max_new_tokens):\n",
-    "            logits = model(input_ids=idx).logits\n",
-    "            logits = logits[:, -1, :]\n",
-    "            probs = F.softmax(logits, dim=-1)\n",
-    "            # 根据概率,随机生成下一个词元\n",
-    "            idx_next = torch.multinomial(probs, num_samples=1)\n",
-    "            idx = torch.cat((idx, idx_next), dim=1)\n",
-    "        return idx\n",
-    "    \n",
-    "    def forward(self, idx):\n",
-    "        # 为了代码简洁,我们设置产生文本的长度\n",
-    "        ans = self.generate(idx, 20)\n",
-    "        reward = self.r_model(ans)\n",
-    "        return reward"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "inputs = '1 + 2 = 3, 2 + 1 = 3, 1 + 2 ='\n",
-    "ids = tokenizer(inputs, return_tensors=\"pt\")\n",
-    "model = RLModel(llm, r_model)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 4, 3 + 1 = 5, 1 + 2 = 6 — Ha ha ha! In us\n"
-     ]
-    }
-   ],
-   "source": [
-    "# 验证generate是正确的\n",
-    "print(tokenizer.decode(model.generate(ids['input_ids'], 20)[0], skip_special_tokens=True))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
-      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
-     ]
-    },
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 4 without action FARMADAM (same) Wooden child Servant use Intel SOCKS+\n"
-     ]
-    }
-   ],
-   "source": [
-    "res = model.llm.generate(\n",
-    "    input_ids=ids['input_ids'], max_new_tokens=20,\n",
-    "    do_sample=True, top_k=0)[0]\n",
-    "print(tokenizer.decode(res, skip_special_tokens=True))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 9,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "RuntimeError",
-     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
-      "\u001b[0;32m<ipython-input-9-b7dbb844b37d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;31m# 将报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
-      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
-      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
-     ]
-    }
-   ],
-   "source": [
-    "loss = -1 * model(ids['input_ids'])\n",
-    "# 将报错\n",
-    "loss.backward()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "tensor([ 928.,  926., 1631.,  340., 6175.])\n",
-      "tensor([ 996.,  865., 1616.,  314., 6209.])\n"
-     ]
-    }
-   ],
-   "source": [
-    "# 实验gumbel_softmax\n",
-    "logits = torch.randn(1, 5)\n",
-    "probs = F.softmax(logits, dim=-1)\n",
-    "y = torch.multinomial(probs, num_samples=10000, replacement=True)\n",
-    "print(torch.histogram(y.float(), bins=5).hist)\n",
-    "gumbel_y = torch.argmax(F.gumbel_softmax(logits.repeat(10000, 1), tau=1, hard=True), dim=-1, keepdim=True)\n",
-    "print(torch.histogram(gumbel_y.float(), bins=5).hist)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 11,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class RLModelWithGumbel(nn.Module):\n",
-    "    \n",
-    "    def __init__(self, llm, r_model):\n",
-    "        super().__init__()\n",
-    "        self.llm = llm\n",
-    "        self.r_model = r_model\n",
-    "        # 冻结模型\n",
-    "        for param in r_model.parameters():\n",
-    "            param.requires_grad = False\n",
-    "    \n",
-    "    def generate(self, idx, max_new_tokens):\n",
-    "        model = self.llm\n",
-    "        B, T = idx.shape\n",
-    "        ans = None\n",
-    "        for _ in range(max_new_tokens):\n",
-    "            logits = model(input_ids=idx).logits\n",
-    "            logits = logits[:, -1, :]\n",
-    "            # 根据概率,随机生成下一个词元\n",
-    "            idx_next_hot = F.gumbel_softmax(logits, tau=1, hard=True)  # (B, vs)\n",
-    "            idx_next = torch.argmax(idx_next_hot, dim=-1, keepdim=True)\n",
-    "            idx = torch.cat((idx, idx_next.long()), dim=1)\n",
-    "            idx_next_hot = idx_next_hot.unsqueeze(1)      # (B, 1, vs)\n",
-    "            if ans == None:\n",
-    "                ans = idx_next_hot\n",
-    "            else:\n",
-    "                ans = torch.cat((ans, idx_next_hot), dim=1)\n",
-    "        return idx, ans\n",
-    "    \n",
-    "    def forward(self, idx):\n",
-    "        # 为了代码简洁,我们设置产生文本的长度\n",
-    "        _, ans = self.generate(idx, 20)\n",
-    "        reward = self.r_model(ans)\n",
-    "        return reward"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 12,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "model_gumbel = RLModelWithGumbel(llm, r_model)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 13,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "tensor([[True, True, True, True, True, True, True, True, True, True, True, True,\n",
-      "         True, True, True, True, True, True, True, True]])\n",
-      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 0, 1 + 1 = 0; extends laugh(cow, decision, discount) fifth person,\n"
-     ]
-    }
-   ],
-   "source": [
-    "# 验证generate正确\n",
-    "idx, ans = model_gumbel.generate(ids['input_ids'], 20)\n",
-    "print(idx[:, ids['input_ids'].shape[1]:] == torch.argmax(ans, dim=-1, keepdim=True).squeeze(-1))\n",
-    "print(tokenizer.decode(idx[0], skip_special_tokens=True))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 14,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(tensor([[-0.2085]]), tensor([[-0.2085]], grad_fn=<MmBackward0>))"
-      ]
-     },
-     "execution_count": 14,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "# 验证评分模型计算正确\n",
-    "model_gumbel.r_model(idx[:, ids['input_ids'].shape[1]:]), model_gumbel.r_model(ans)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 15,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "tensor([[ 2.3994e-06,  4.8380e-06,  3.5403e-06,  ...,  4.4225e-06,\n",
-       "         -1.5709e-06,  4.8997e-06],\n",
-       "        [ 4.4208e-05,  1.3246e-04,  1.4072e-05,  ...,  7.9197e-05,\n",
-       "         -1.4321e-06, -6.9506e-06],\n",
-       "        [ 7.8832e-06,  5.7550e-06, -1.3545e-07,  ...,  5.6032e-06,\n",
-       "         -5.2948e-06,  1.6141e-06],\n",
-       "        ...,\n",
-       "        [ 6.0610e-10,  9.2871e-10,  3.8407e-10,  ...,  1.6127e-09,\n",
-       "         -1.6454e-09, -8.2414e-10],\n",
-       "        [-1.5970e-09,  4.7921e-09,  6.8945e-09,  ...,  7.0852e-09,\n",
-       "         -7.1524e-09, -1.9468e-09],\n",
-       "        [ 3.6735e-04,  2.7833e-04,  3.1601e-05,  ...,  1.5014e-05,\n",
-       "          3.1863e-04, -2.6312e-04]])"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "loss = -1 * model_gumbel(ids['input_ids'])\n",
-    "# 成功运行反向传播算法\n",
-    "loss.backward()\n",
-    "list(model_gumbel.llm.parameters())[0].grad"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.8.5"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}

+ 450 - 82
ch12_rl/llm_ppo.ipynb

@@ -2,16 +2,22 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
+   "execution_count": 1,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "ZTO7hf-zM-np",
+    "outputId": "ec76c1e3-8c42-4b46-c4d5-a45f5e1029ce"
+   },
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "<torch._C.Generator at 0x7fc57cc63110>"
+       "<torch._C.Generator at 0x790a9c0a93b0>"
       ]
      },
-     "execution_count": 2,
+     "execution_count": 1,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -20,75 +26,181 @@
     "import torch\n",
     "import torch.nn as nn\n",
     "import torch.nn.functional as F\n",
-    "from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel, GPT2Model\n",
-    "\n",
+    "from torch.nn.utils import clip_grad_norm_\n",
+    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+    "import torch.optim as optim\n",
+    "from datasets import load_dataset\n",
+    "from transformers import pipeline\n",
     "\n",
     "torch.manual_seed(12046)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
+   "execution_count": 2,
+   "metadata": {
+    "id": "dJhQvyIYM-nr"
+   },
    "outputs": [],
    "source": [
-    "learning_rate = 6e-4\n",
-    "sequence_len = 1024\n",
-    "batch_size = 8\n",
-    "gra_acc_steps = 8 * 2\n",
+    "learning_rate = 1e-4\n",
     "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
-    "eval_iters = 64 * 2\n",
-    "eval_interval = 50"
+    "gamma = 1.0\n",
+    "lambda_ = 0.95\n",
+    "kl_ctl_value = 0.2\n",
+    "cliprange = 0.2\n",
+    "vf_coef = 0.1\n",
+    "mini_batch_size = 20\n",
+    "grad_clip = 1.0"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
-   "metadata": {},
+   "execution_count": 3,
+   "metadata": {
+    "id": "Bh967HO6M-ns"
+   },
    "outputs": [],
    "source": [
-    "llm = GPT2LMHeadModel.from_pretrained('gpt2')\n",
-    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
+    "tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
+    "tokenizer.pad_token = tokenizer.eos_token"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
-   "metadata": {},
+   "execution_count": 4,
+   "metadata": {
+    "id": "ZZWmZrRlM-ns"
+   },
+   "outputs": [],
+   "source": [
+    "def prepare_input(data):\n",
+    "    data['input_ids'] = [tokenizer.encode(data['text'])[:8]]\n",
+    "    return data\n",
+    "\n",
+    "datasets = load_dataset('imdb', split='train[:500]')\n",
+    "datasets = datasets.filter(lambda x: len(x['text']) > 20)\n",
+    "tokenized = datasets.map(prepare_input, remove_columns=datasets.column_names)\n",
+    "tokenized.set_format(type='torch', device=device)\n",
+    "example = tokenized[1]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "id": "cH6wexbYM-nw",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "class A2CLLM(nn.Module):\n",
+    "\n",
+    "    def __init__(self, model):\n",
+    "        super().__init__()\n",
+    "        self.actor = model\n",
+    "        self.critic = nn.Linear(model.base_model.embed_dim, 1, bias=False)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        _res = self.actor(input_ids=x, output_hidden_states=True)\n",
+    "        logits = _res.logits\n",
+    "        emb = _res.hidden_states[-1]\n",
+    "        values = self.critic(emb).squeeze(-1)\n",
+    "        return logits, values\n",
+    "\n",
+    "    def generate(self, idx, max_new_tokens=20):\n",
+    "        model = self.actor\n",
+    "        return model.generate(idx, max_new_tokens=max_new_tokens,\n",
+    "                             pad_token_id=tokenizer.eos_token_id)\n",
+    "\n",
+    "model = A2CLLM(AutoModelForCausalLM.from_pretrained('lvwerra/gpt2-imdb')).to(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "id": "EO_cOnnCM-nw"
+   },
+   "outputs": [],
+   "source": [
+    "from peft import LoraConfig, PeftModel\n",
+    "\n",
+    "def init_peft_model(model):\n",
+    "    config = LoraConfig(\n",
+    "        r=1,\n",
+    "        lora_alpha=8,\n",
+    "        target_modules=['c_attn'],\n",
+    "        fan_in_fan_out=True,\n",
+    "        bias='none',\n",
+    "        modules_to_save=['critic'])\n",
+    "    return PeftModel(model, config, adapter_name='lora_ppo')\n",
+    "\n",
+    "model = init_peft_model(model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "tv2HXFQwM-nw",
+    "outputId": "f2ee372f-6a55-498c-8260-8379ae8a6c2f"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "logits torch.Size([1, 20, 50257])\n",
+      "lnp torch.Size([1, 20])\n",
+      "values torch.Size([1, 20])\n"
+     ]
+    }
+   ],
+   "source": [
+    "def get_forward_result(model, input_ids, response):\n",
+    "    model.eval()\n",
+    "    _, lens = input_ids.shape\n",
+    "    logits, values = model(response)\n",
+    "    lnp = -F.cross_entropy(logits[:, :-1, :].transpose(-2, -1), response[:, 1:], reduction='none')\n",
+    "    res = {\n",
+    "        'logits': logits[:, lens-1:-1, :],\n",
+    "        'lnp': lnp[:, lens-1:],\n",
+    "        'values': values[:, lens:]\n",
+    "    }\n",
+    "    model.train()\n",
+    "    return res\n",
+    "\n",
+    "\n",
+    "input_ids = example['input_ids']\n",
+    "response = model.generate(input_ids)\n",
+    "\n",
+    "example_re = get_forward_result(model, input_ids, response)\n",
+    "for k, v in example_re.items():\n",
+    "    print(k, v.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "wW8OBSb6M-nx",
+    "outputId": "9ce6eba9-9a9a-4ca1-a05b-72d68d25fa32"
+   },
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "GPT2LMHeadModel(\n",
-       "  (transformer): GPT2Model(\n",
-       "    (wte): Embedding(50257, 768)\n",
-       "    (wpe): Embedding(1024, 768)\n",
-       "    (drop): Dropout(p=0.1, inplace=False)\n",
-       "    (h): ModuleList(\n",
-       "      (0-11): 12 x GPT2Block(\n",
-       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
-       "        (attn): GPT2Attention(\n",
-       "          (c_attn): Conv1D()\n",
-       "          (c_proj): Conv1D()\n",
-       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
-       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
-       "        )\n",
-       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
-       "        (mlp): GPT2MLP(\n",
-       "          (c_fc): Conv1D()\n",
-       "          (c_proj): Conv1D()\n",
-       "          (act): NewGELUActivation()\n",
-       "          (dropout): Dropout(p=0.1, inplace=False)\n",
-       "        )\n",
-       "      )\n",
-       "    )\n",
-       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
-       "  )\n",
-       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
-       ")"
+       "tensor([0.9959])"
       ]
      },
-     "execution_count": 11,
+     "execution_count": 8,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -96,48 +208,304 @@
    "source": [
     "class RewardModel(nn.Module):\n",
     "\n",
-    "    def __init__(self, model):\n",
+    "    def __init__(self, tokenizer):\n",
     "        super().__init__()\n",
-    "        self.embedding = model\n",
-    "        self.score = nn.Linear(model.embed_dim, 1, bias=False)\n",
-    "\n",
-    "    def forward(self, x, seq_len=None):\n",
-    "        # x:表示文本,形状(B, T, vs)或者(B, T), seq_len:表示文本长度,形状(B)\n",
-    "        B = x.shape[0]\n",
-    "        T = x.shape[1]\n",
-    "        emb = self.get_last_hidden_state(x)     # (B, T, C)\n",
-    "        ind = torch.arange(B, device=x.device)\n",
-    "        if seq_len == None:\n",
-    "            seq_len = torch.tensor([T] * B)\n",
-    "        # 获取最后一个词元的特征\n",
-    "        pooled_emb = emb[ind, seq_len - 1]      # (B,    C)\n",
-    "        score = self.score(pooled_emb)          # (B,    1)\n",
-    "        return score\n",
-    "    \n",
-    "    def get_last_hidden_state(self, x):\n",
-    "        if len(x.shape) == 2:\n",
-    "            # x shape = (B, T)\n",
-    "            emb = self.embedding(x).last_hidden_state  # (B, T, C)\n",
-    "        # 为后面使用gumbel_softmax做准备,直接与embedding的模型参数进行计算\n",
-    "        else:\n",
-    "            # x shape = (B, T, vs)\n",
-    "            w = self.embedding.get_input_embeddings().weight  # (vs, C)\n",
-    "            inputs_embeds = x @ w  # (B, T, C)\n",
-    "            emb = self.embedding(inputs_embeds=inputs_embeds).last_hidden_state\n",
-    "        return emb\n",
-    "\n",
-    "r_model = RewardModel(GPT2Model.from_pretrained('gpt2'))"
+    "        self.model = pipeline(\"sentiment-analysis\", model='lvwerra/distilbert-imdb')\n",
+    "        self.tokenizer = tokenizer\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        re = []\n",
+    "        x = [self.tokenizer.decode(i) for i in x]\n",
+    "        scores = self.model(x)\n",
+    "        for s in scores:\n",
+    "            if s['label'] == 'POSITIVE':\n",
+    "                re.append(s['score'])\n",
+    "            else:\n",
+    "                re.append(1 - s['score'])\n",
+    "        return torch.tensor(re)\n",
+    "\n",
+    "r_model = RewardModel(tokenizer).to(device)\n",
+    "r_model(response)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "1rGQcB0PM-nx",
+    "outputId": "260a7ff9-55ca-41b8-e868-606b710c8b16"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 20])"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def compute_rewards(r_model, response, lnp, ref_lnp):\n",
+    "    # scores: (B), lnp: (B, T), ref_lnp: (B, T)\n",
+    "    # r_model:评分模型,response:模型生成的回答\n",
+    "    # lnp:新/旧模型的概率对数,ref_lnp:参考模型的概率对数\n",
+    "    scores = r_model(response)\n",
+    "    rewards = []\n",
+    "    for score, lnprob, ref_lnprob in zip(scores, lnp, ref_lnp):\n",
+    "        kl = lnprob - ref_lnprob\n",
+    "        # kl_ctl_value是调节KL penalty的系数,大于0\n",
+    "        reward = -kl_ctl_value * kl\n",
+    "        # 游戏奖励等于模型评分 + KL penalty\n",
+    "        reward[-1] += score\n",
+    "        rewards.append(reward)\n",
+    "    return torch.stack(rewards)\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    with model.disable_adapter():\n",
+    "        ref_example_re = get_forward_result(model, input_ids, response)\n",
+    "\n",
+    "rewards = compute_rewards(r_model, response, example_re['lnp'], ref_example_re['lnp'])\n",
+    "rewards.shape"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
+   "execution_count": 10,
+   "metadata": {
+    "id": "huNDSgciM-nx"
+   },
    "outputs": [],
-   "source": []
+   "source": [
+    "class GAE:\n",
+    "\n",
+    "    def __init__(self, gamma, lambda_):\n",
+    "        self.gamma = gamma\n",
+    "        self.lambda_ = lambda_\n",
+    "\n",
+    "    def __call__(self, rewards, values):\n",
+    "        # advantages table\n",
+    "        advantages = []\n",
+    "        last_advantage = 0\n",
+    "        vt_next = 0\n",
+    "        for r, vt in zip(reversed(rewards), reversed(values)):\n",
+    "            delta = r + self.gamma * vt_next - vt\n",
+    "            last_advantage = delta + self.gamma * self.lambda_ * last_advantage\n",
+    "            advantages.insert(0, last_advantage)\n",
+    "            vt_next = vt\n",
+    "\n",
+    "        return torch.stack(advantages)\n",
+    "\n",
+    "gae = GAE(gamma, lambda_)\n",
+    "advantages = gae(rewards, example_re['values'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "MU3Sz6iwM-ny",
+    "outputId": "e96c8950-e9aa-4650-84e9-e799a4e1e9d7"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(-0.2746, device='cuda:0', grad_fn=<AddBackward0>)"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def compute_loss(old_lnp, lnp, vpred, advantages):\n",
+    "    # 值函数损失\n",
+    "    vf_loss = -advantages * vpred\n",
+    "    # 策略损失\n",
+    "    ratio = torch.exp(lnp - old_lnp)\n",
+    "    pg_losses = -advantages * ratio\n",
+    "    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)\n",
+    "    pg_loss = torch.max(pg_losses, pg_losses2)\n",
+    "    # 整体损失\n",
+    "    loss = pg_loss.mean() + vf_coef * vf_loss.mean()\n",
+    "    return loss\n",
+    "\n",
+    "compute_loss(example_re['lnp'], example_re['lnp'], example_re['values'], advantages)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "v_yIwG4tM-nz",
+    "outputId": "7311662e-e1e9-4502-b632-bfab4fba3c39"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[tensor([[   40, 26399,   314,  3001,   327, 47269, 20958,    12]],\n",
+       "        device='cuda:0'),\n",
+       " tensor([[    1,    40,  1703, 44269,    25, 12550,     1,   318]],\n",
+       "        device='cuda:0')]"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def play_game(model, r_model, gae, data):\n",
+    "    model.eval()\n",
+    "    all_input_ids, all_response, all_res, all_advantages = [], [], [], []\n",
+    "    for input_ids in data['input_ids']:\n",
+    "        all_input_ids.append(input_ids)\n",
+    "        # 生成评论\n",
+    "        response = model.generate(input_ids)\n",
+    "        all_response.append(response)\n",
+    "        with torch.no_grad():\n",
+    "            # 记录旧模型数据\n",
+    "            res = get_forward_result(model, input_ids, response)\n",
+    "            all_res.append(res)\n",
+    "            # 记录参考模型数据\n",
+    "            with model.disable_adapter():\n",
+    "                ref_res = get_forward_result(model, input_ids, response)\n",
+    "            rewards = compute_rewards(r_model, response, res['lnp'], ref_res['lnp'])\n",
+    "            all_advantages.append(gae(rewards, res['values']))\n",
+    "    model.train()\n",
+    "    return all_input_ids, all_response, all_res, all_advantages\n",
+    "\n",
+    "play_game(model, r_model, gae, tokenized[:2])[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "sPjg11nEM-nz",
+    "outputId": "9ed9c1f9-5726-464c-d618-858bc10deee9"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'score': 0.5244841426610947, 'ref_score': 0.5244841426610947}"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def estimate_rewards(r_model, model, all_input_ids):\n",
+    "    re = {}\n",
+    "    # 将模型切换至评估模式\n",
+    "    model.eval()\n",
+    "    for input_ids in all_input_ids:\n",
+    "        response = model.generate(input_ids)\n",
+    "        re['score'] = re.get('score', 0) + r_model(response).item()\n",
+    "        with model.disable_adapter():\n",
+    "            response = model.generate(input_ids)\n",
+    "            re['ref_score'] = re.get('ref_score', 0) + r_model(response).item()\n",
+    "    re['score'] /= len(all_input_ids)\n",
+    "    re['ref_score'] /= len(all_input_ids)\n",
+    "    # 将模型切换至训练模式\n",
+    "    model.train()\n",
+    "    return re\n",
+    "\n",
+    "estimate_rewards(r_model, model, tokenized[:20]['input_ids'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "mizmt8YrM-n0",
+    "outputId": "5d53402d-04d4-4dde-e68c-8a2f1c941b77"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "step    0: score 0.5415, ref_score 0.5085\n",
+      "step    1: score 0.5412, ref_score 0.5085\n",
+      "step    2: score 0.5182, ref_score 0.5085\n",
+      "step    3: score 0.5183, ref_score 0.5085\n",
+      "step    4: score 0.5234, ref_score 0.5085\n",
+      "step    5: score 0.5589, ref_score 0.5085\n",
+      "step    6: score 0.5977, ref_score 0.5085\n",
+      "step    7: score 0.5754, ref_score 0.5085\n",
+      "step    8: score 0.5707, ref_score 0.5085\n",
+      "step    9: score 0.5677, ref_score 0.5085\n",
+      "step   10: score 0.5692, ref_score 0.5085\n",
+      "step   11: score 0.6209, ref_score 0.5085\n",
+      "step   12: score 0.6320, ref_score 0.5085\n",
+      "step   13: score 0.6743, ref_score 0.5085\n",
+      "step   14: score 0.6690, ref_score 0.5085\n",
+      "step   15: score 0.6042, ref_score 0.5085\n",
+      "step   16: score 0.6386, ref_score 0.5085\n",
+      "step   17: score 0.6035, ref_score 0.5085\n",
+      "step   18: score 0.6028, ref_score 0.5085\n",
+      "step   19: score 0.6148, ref_score 0.5085\n",
+      "step   20: score 0.6147, ref_score 0.5085\n",
+      "step   21: score 0.6702, ref_score 0.5085\n",
+      "step   22: score 0.7225, ref_score 0.5085\n",
+      "step   23: score 0.7192, ref_score 0.5085\n"
+     ]
+    }
+   ],
+   "source": [
+    "steps = datasets.num_rows // mini_batch_size\n",
+    "optimizer = optim.AdamW(model.parameters(), lr=learning_rate)\n",
+    "\n",
+    "for s in range(steps-1):\n",
+    "    data = tokenized[s * mini_batch_size: (s + 1) * mini_batch_size]\n",
+    "    # 进行游戏,收集数据。play_game返回的数据都是无法计算梯度的\n",
+    "    # 在play_game中,会基于model生成参考模型\n",
+    "    input_ids, response, old_res, advantages = play_game(model, r_model, gae, data)\n",
+    "    # 循环完成之后,才用新模型替换旧模型\n",
+    "    for _ids, _resp, _old_res, _ad in zip(input_ids, response, old_res, advantages):\n",
+    "        optimizer.zero_grad(set_to_none=True)\n",
+    "        # 收集新模型的数据,model_res里面的数据可以计算梯度\n",
+    "        model_res = get_forward_result(model, _ids, _resp)\n",
+    "        loss = compute_loss(_old_res['lnp'], model_res['lnp'], model_res['values'], _ad)\n",
+    "        loss.backward()\n",
+    "        # 梯度裁剪\n",
+    "        clip_grad_norm_(model.parameters(), grad_clip)\n",
+    "        optimizer.step()\n",
+    "    res = estimate_rewards(r_model, model, tokenized[-mini_batch_size:]['input_ids'])\n",
+    "    print(f'step {s:>4}: score {res[\"score\"]:.4f}, ref_score {res[\"ref_score\"]:.4f}')"
+   ]
   }
  ],
  "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "V100",
+   "provenance": []
+  },
   "kernelspec": {
    "display_name": "Python 3",
    "language": "python",
@@ -157,5 +525,5 @@
   }
  },
  "nbformat": 4,
- "nbformat_minor": 4
+ "nbformat_minor": 1
 }

+ 601 - 0
ch12_rl/llm_ppo_correct_dropout.ipynb

@@ -0,0 +1,601 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "ZTO7hf-zM-np",
+    "outputId": "e8944765-badd-4a6c-c111-0ca802c1f23b"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7e59d435d310>"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "from torch.nn.utils import clip_grad_norm_\n",
+    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
+    "import torch.optim as optim\n",
+    "from datasets import load_dataset\n",
+    "from transformers import pipeline\n",
+    "\n",
+    "torch.manual_seed(12046)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "id": "dJhQvyIYM-nr"
+   },
+   "outputs": [],
+   "source": [
+    "learning_rate = 5e-5\n",
+    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+    "gamma = 1.0\n",
+    "lambda_ = 0.95\n",
+    "kl_ctl_value = 0.2\n",
+    "cliprange = 0.2\n",
+    "vf_coef = 0.1\n",
+    "mini_batch_size = 20\n",
+    "grad_clip = 1.0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "id": "Bh967HO6M-ns"
+   },
+   "outputs": [],
+   "source": [
+    "tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
+    "tokenizer.pad_token = tokenizer.eos_token"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "id": "ZZWmZrRlM-ns"
+   },
+   "outputs": [],
+   "source": [
+    "def prepare_input(data):\n",
+    "    data['input_ids'] = [tokenizer.encode(data['text'])[:8]]\n",
+    "    return data\n",
+    "\n",
+    "datasets = load_dataset('imdb', split='train[:500]')\n",
+    "datasets = datasets.filter(lambda x: len(x['text']) > 20)\n",
+    "tokenized = datasets.map(prepare_input, remove_columns=datasets.column_names)\n",
+    "tokenized.set_format(type='torch', device=device)\n",
+    "example = tokenized[1]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "id": "cH6wexbYM-nw",
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "class A2CLLM(nn.Module):\n",
+    "\n",
+    "    def __init__(self, model):\n",
+    "        super().__init__()\n",
+    "        self.actor = model\n",
+    "        self.critic = nn.Linear(model.base_model.embed_dim, 1, bias=False)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        _res = self.actor(input_ids=x, output_hidden_states=True)\n",
+    "        logits = _res.logits\n",
+    "        emb = _res.hidden_states[-1]\n",
+    "        values = self.critic(emb).squeeze(-1)\n",
+    "        return logits, values\n",
+    "\n",
+    "    def generate(self, idx, max_new_tokens=20):\n",
+    "        model = self.actor\n",
+    "        return model.generate(idx, max_new_tokens=max_new_tokens,\n",
+    "                             pad_token_id=tokenizer.eos_token_id)\n",
+    "\n",
+    "model = A2CLLM(AutoModelForCausalLM.from_pretrained('lvwerra/gpt2-imdb')).to(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "id": "EO_cOnnCM-nw"
+   },
+   "outputs": [],
+   "source": [
+    "from peft import LoraConfig, PeftModel\n",
+    "\n",
+    "def init_peft_model(model):\n",
+    "    config = LoraConfig(\n",
+    "        r=1,\n",
+    "        lora_alpha=8,\n",
+    "        target_modules=['c_attn'],\n",
+    "        fan_in_fan_out=True,\n",
+    "        lora_dropout=0.1,\n",
+    "        bias='none',\n",
+    "        modules_to_save=['critic'])\n",
+    "    return PeftModel(model, config, adapter_name='lora_ppo')\n",
+    "\n",
+    "model = init_peft_model(model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "tv2HXFQwM-nw",
+    "outputId": "fbd5f1e8-7cfd-4148-df83-1cc6059958b5"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "logits torch.Size([1, 20, 50257])\n",
+      "lnp torch.Size([1, 20])\n",
+      "values torch.Size([1, 20])\n"
+     ]
+    }
+   ],
+   "source": [
+    "def get_forward_result(model, input_ids, response):\n",
+    "    _, lens = input_ids.shape\n",
+    "    logits, values = model(response)\n",
+    "    lnp = -F.cross_entropy(logits[:, :-1, :].transpose(-2, -1), response[:, 1:], reduction='none')\n",
+    "    res = {\n",
+    "        'logits': logits[:, lens-1:-1, :],\n",
+    "        'lnp': lnp[:, lens-1:],\n",
+    "        'values': values[:, lens:]\n",
+    "    }\n",
+    "    return res\n",
+    "\n",
+    "\n",
+    "input_ids = example['input_ids']\n",
+    "response = model.generate(input_ids)\n",
+    "\n",
+    "example_re = get_forward_result(model, input_ids, response)\n",
+    "for k, v in example_re.items():\n",
+    "    print(k, v.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "Rkl8rTvBcOQX",
+    "outputId": "01dadb70-33c5-4a79-8991-199b24d80d81"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[ 0.3356, -0.3501, -0.6011, -0.4132,  1.0261,  0.8811, -0.3165,  0.4929,\n",
+      "         -0.9196, -0.3321, -0.2723, -0.1996, -0.6541,  0.1892,  0.6956,  0.3488,\n",
+      "          0.2956,  0.3583,  0.2754,  0.5844,  0.7313,  0.1374,  0.5127, -0.1030,\n",
+      "          0.5666, -0.0081,  0.3219, -0.0353]], device='cuda:0',\n",
+      "       grad_fn=<SubBackward0>)\n",
+      "tensor([[ 0.0418,  0.5579,  0.5273,  1.0549,  0.5402,  0.1473,  0.3205,  0.0311,\n",
+      "          0.6900, -0.2323,  0.1526,  0.4450,  0.1746,  0.6160, -0.2214, -0.1989,\n",
+      "          0.1022,  0.2701, -0.0173, -0.0539, -0.1477,  0.0678, -0.0153, -0.6429,\n",
+      "         -0.3822, -0.4266, -0.2184, -0.4352]], device='cuda:0',\n",
+      "       grad_fn=<SubBackward0>)\n",
+      "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
+      "         0., 0., 0., 0.]], device='cuda:0', grad_fn=<SubBackward0>)\n"
+     ]
+    }
+   ],
+   "source": [
+    "def turn_on_train_mode(model, target):\n",
+    "    for name, module in model.named_modules():\n",
+    "        if name.split('.')[-1] in target:\n",
+    "            module.train()\n",
+    "    return model\n",
+    "\n",
+    "def _test_turn_on_train_mode():\n",
+    "    test_model = A2CLLM(\n",
+    "        AutoModelForCausalLM.from_pretrained('lvwerra/gpt2-imdb')).to(device)\n",
+    "    config = LoraConfig(\n",
+    "        r=1,\n",
+    "        lora_alpha=8,\n",
+    "        target_modules=['c_attn'],\n",
+    "        fan_in_fan_out=True,\n",
+    "        lora_dropout=0.1,\n",
+    "        bias='none',\n",
+    "        init_lora_weights=False)\n",
+    "    test_model = PeftModel(test_model, config, adapter_name='lora_ppo')\n",
+    "    test_model.train()\n",
+    "    v1 = test_model(response)[1]\n",
+    "    v2 = test_model(response)[1]\n",
+    "    # 不相等\n",
+    "    print(v1 - v2)\n",
+    "\n",
+    "    test_model.eval()\n",
+    "    turn_on_train_mode(test_model, ['c_attn'])\n",
+    "    v1 = test_model(response)[1]\n",
+    "    v2 = test_model(response)[1]\n",
+    "    # 不相等\n",
+    "    print(v1 - v2)\n",
+    "\n",
+    "    test_model.eval()\n",
+    "    turn_on_train_mode(test_model, ['c_attn'])\n",
+    "    with test_model.disable_adapter():\n",
+    "        v1 = test_model(response)[1]\n",
+    "        v2 = test_model(response)[1]\n",
+    "        # 相等\n",
+    "        print(v1 - v2)\n",
+    "\n",
+    "_test_turn_on_train_mode()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "wW8OBSb6M-nx",
+    "outputId": "a2faedc5-e036-4c3e-b29a-0f4acb924448"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([0.9959])"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "class RewardModel(nn.Module):\n",
+    "\n",
+    "    def __init__(self, tokenizer):\n",
+    "        super().__init__()\n",
+    "        self.model = pipeline(\"sentiment-analysis\", model='lvwerra/distilbert-imdb')\n",
+    "        self.tokenizer = tokenizer\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        re = []\n",
+    "        x = [self.tokenizer.decode(i) for i in x]\n",
+    "        scores = self.model(x)\n",
+    "        for s in scores:\n",
+    "            if s['label'] == 'POSITIVE':\n",
+    "                re.append(s['score'])\n",
+    "            else:\n",
+    "                re.append(1 - s['score'])\n",
+    "        return torch.tensor(re)\n",
+    "\n",
+    "r_model = RewardModel(tokenizer).to(device)\n",
+    "r_model(response)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "1rGQcB0PM-nx",
+    "outputId": "2dfc1902-72f5-43af-9234-b224ebd52959"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 20])"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def compute_rewards(r_model, response, lnp, ref_lnp):\n",
+    "    # scores: (B), lnp: (B, T), ref_lnp: (B, T)\n",
+    "    # r_model:评分模型,response:模型生成的回答\n",
+    "    # lnp:新/旧模型的概率对数,ref_lnp:参考模型的概率对数\n",
+    "    scores = r_model(response)\n",
+    "    rewards = []\n",
+    "    for score, lnprob, ref_lnprob in zip(scores, lnp, ref_lnp):\n",
+    "        kl = lnprob - ref_lnprob\n",
+    "        # kl_ctl_value是调节KL penalty的系数,大于0\n",
+    "        reward = -kl_ctl_value * kl\n",
+    "        # 游戏奖励等于模型评分 + KL penalty\n",
+    "        reward[-1] += score\n",
+    "        rewards.append(reward)\n",
+    "    return torch.stack(rewards)\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    with model.disable_adapter():\n",
+    "        ref_example_re = get_forward_result(model, input_ids, response)\n",
+    "\n",
+    "rewards = compute_rewards(r_model, response, example_re['lnp'], ref_example_re['lnp'])\n",
+    "rewards.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "id": "huNDSgciM-nx"
+   },
+   "outputs": [],
+   "source": [
+    "class GAE:\n",
+    "\n",
+    "    def __init__(self, gamma, lambda_):\n",
+    "        self.gamma = gamma\n",
+    "        self.lambda_ = lambda_\n",
+    "\n",
+    "    def __call__(self, rewards, values):\n",
+    "        # advantages table\n",
+    "        advantages = []\n",
+    "        last_advantage = 0\n",
+    "        vt_next = 0\n",
+    "        for r, vt in zip(reversed(rewards), reversed(values)):\n",
+    "            delta = r + self.gamma * vt_next - vt\n",
+    "            last_advantage = delta + self.gamma * self.lambda_ * last_advantage\n",
+    "            advantages.insert(0, last_advantage)\n",
+    "            vt_next = vt\n",
+    "\n",
+    "        return torch.stack(advantages)\n",
+    "\n",
+    "gae = GAE(gamma, lambda_)\n",
+    "advantages = gae(rewards, example_re['values'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "MU3Sz6iwM-ny",
+    "outputId": "c104d182-1547-4596-8092-4d1d117cbe95"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(-0.2746, device='cuda:0', grad_fn=<AddBackward0>)"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def compute_loss(old_lnp, lnp, vpred, advantages):\n",
+    "    # 值函数损失\n",
+    "    vf_loss = -advantages * vpred\n",
+    "    # 策略损失\n",
+    "    ratio = torch.exp(lnp - old_lnp)\n",
+    "    pg_losses = -advantages * ratio\n",
+    "    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)\n",
+    "    pg_loss = torch.max(pg_losses, pg_losses2)\n",
+    "    # 整体损失\n",
+    "    loss = pg_loss.mean() + vf_coef * vf_loss.mean()\n",
+    "    return loss\n",
+    "\n",
+    "compute_loss(example_re['lnp'], example_re['lnp'], example_re['values'], advantages)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "v_yIwG4tM-nz",
+    "outputId": "adb33a65-0b81-4955-f4fb-85be56dc33f3"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[tensor([[   40, 26399,   314,  3001,   327, 47269, 20958,    12]],\n",
+       "        device='cuda:0'),\n",
+       " tensor([[    1,    40,  1703, 44269,    25, 12550,     1,   318]],\n",
+       "        device='cuda:0')]"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def play_game(model, r_model, gae, data):\n",
+    "    model.eval()\n",
+    "    all_input_ids, all_response, all_res, all_advantages = [], [], [], []\n",
+    "    for input_ids in data['input_ids']:\n",
+    "        all_input_ids.append(input_ids)\n",
+    "        # 生成评论\n",
+    "        response = model.generate(input_ids)\n",
+    "        all_response.append(response)\n",
+    "        with torch.no_grad():\n",
+    "            # 记录旧模型数据\n",
+    "            res = get_forward_result(model, input_ids, response)\n",
+    "            all_res.append(res)\n",
+    "            # 记录参考模型数据\n",
+    "            with model.disable_adapter():\n",
+    "                ref_res = get_forward_result(model, input_ids, response)\n",
+    "            rewards = compute_rewards(r_model, response, res['lnp'], ref_res['lnp'])\n",
+    "            all_advantages.append(gae(rewards, res['values']))\n",
+    "    turn_on_train_mode(model, ['c_attn'])\n",
+    "    return all_input_ids, all_response, all_res, all_advantages\n",
+    "\n",
+    "play_game(model, r_model, gae, tokenized[:2])[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "sPjg11nEM-nz",
+    "outputId": "bcc0c350-c4bd-4485-85bc-2c718bc6a020"
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'score': 0.5244841426610947, 'ref_score': 0.5244841426610947}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def estimate_rewards(r_model, model, all_input_ids):\n",
+    "    re = {}\n",
+    "    # 将模型切换至评估模式\n",
+    "    model.eval()\n",
+    "    for input_ids in all_input_ids:\n",
+    "        response = model.generate(input_ids)\n",
+    "        re['score'] = re.get('score', 0) + r_model(response).item()\n",
+    "        with model.disable_adapter():\n",
+    "            response = model.generate(input_ids)\n",
+    "            re['ref_score'] = re.get('ref_score', 0) + r_model(response).item()\n",
+    "    re['score'] /= len(all_input_ids)\n",
+    "    re['ref_score'] /= len(all_input_ids)\n",
+    "    # 将模型切换至训练模式\n",
+    "    turn_on_train_mode(model, ['c_attn'])\n",
+    "    return re\n",
+    "\n",
+    "estimate_rewards(r_model, model, tokenized[:20]['input_ids'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "mizmt8YrM-n0",
+    "outputId": "f6614c34-3780-4655-821a-061daa007119"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "step    0: score 0.5412, ref_score 0.5085\n",
+      "step    1: score 0.5412, ref_score 0.5085\n",
+      "step    2: score 0.5085, ref_score 0.5085\n",
+      "step    3: score 0.5412, ref_score 0.5085\n",
+      "step    4: score 0.5180, ref_score 0.5085\n",
+      "step    5: score 0.5182, ref_score 0.5085\n",
+      "step    6: score 0.4743, ref_score 0.5085\n",
+      "step    7: score 0.4743, ref_score 0.5085\n",
+      "step    8: score 0.4741, ref_score 0.5085\n",
+      "step    9: score 0.4741, ref_score 0.5085\n",
+      "step   10: score 0.4725, ref_score 0.5085\n",
+      "step   11: score 0.5210, ref_score 0.5085\n",
+      "step   12: score 0.5225, ref_score 0.5085\n",
+      "step   13: score 0.5168, ref_score 0.5085\n",
+      "step   14: score 0.5184, ref_score 0.5085\n",
+      "step   15: score 0.5135, ref_score 0.5085\n",
+      "step   16: score 0.5147, ref_score 0.5085\n",
+      "step   17: score 0.5129, ref_score 0.5085\n",
+      "step   18: score 0.6062, ref_score 0.5085\n",
+      "step   19: score 0.6182, ref_score 0.5085\n",
+      "step   20: score 0.6737, ref_score 0.5085\n",
+      "step   21: score 0.6730, ref_score 0.5085\n",
+      "step   22: score 0.6731, ref_score 0.5085\n",
+      "step   23: score 0.6724, ref_score 0.5085\n"
+     ]
+    }
+   ],
+   "source": [
+    "steps = datasets.num_rows // mini_batch_size\n",
+    "optimizer = optim.AdamW(model.parameters(), lr=learning_rate)\n",
+    "\n",
+    "for s in range(steps-1):\n",
+    "    data = tokenized[s * mini_batch_size: (s + 1) * mini_batch_size]\n",
+    "    # 进行游戏,收集数据。play_game返回的数据都是无法计算梯度的\n",
+    "    # 在play_game中,会基于model生成参考模型\n",
+    "    input_ids, response, old_res, advantages = play_game(model, r_model, gae, data)\n",
+    "    # 循环完成之后,才用新模型替换旧模型\n",
+    "    for _ids, _resp, _old_res, _ad in zip(input_ids, response, old_res, advantages):\n",
+    "        optimizer.zero_grad(set_to_none=True)\n",
+    "        # 收集新模型的数据,model_res里面的数据可以计算梯度\n",
+    "        model_res = get_forward_result(model, _ids, _resp)\n",
+    "        loss = compute_loss(_old_res['lnp'], model_res['lnp'], model_res['values'], _ad)\n",
+    "        loss.backward()\n",
+    "        # 梯度裁剪\n",
+    "        clip_grad_norm_(model.parameters(), grad_clip)\n",
+    "        optimizer.step()\n",
+    "    res = estimate_rewards(r_model, model, tokenized[-mini_batch_size:]['input_ids'])\n",
+    "    print(f'step {s:>4}: score {res[\"score\"]:.4f}, ref_score {res[\"ref_score\"]:.4f}')"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "V100",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}