Gen TANG 2 ani în urmă
părinte
comite
b8c0675877

+ 2994 - 0
ch12_rl/Untitled.ipynb

@@ -0,0 +1,2994 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 115,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n",
+      "\u001b[0mCollecting trl\n",
+      "  Using cached trl-0.7.4-py3-none-any.whl.metadata (10 kB)\n",
+      "Requirement already satisfied: torch>=1.4.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (2.0.1)\n",
+      "Requirement already satisfied: transformers>=4.18.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (4.31.0)\n",
+      "Requirement already satisfied: numpy>=1.18.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (1.24.4)\n",
+      "Requirement already satisfied: accelerate in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (0.21.0)\n",
+      "Requirement already satisfied: datasets in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (2.14.2)\n",
+      "Requirement already satisfied: tyro>=0.5.11 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (0.5.12)\n",
+      "Requirement already satisfied: filelock in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (3.0.12)\n",
+      "Requirement already satisfied: typing-extensions in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (4.8.0)\n",
+      "Requirement already satisfied: sympy in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (1.6.2)\n",
+      "Requirement already satisfied: networkx in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (2.5)\n",
+      "Requirement already satisfied: jinja2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (2.11.2)\n",
+      "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.16.4)\n",
+      "Requirement already satisfied: packaging>=20.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (23.1)\n",
+      "Requirement already satisfied: pyyaml>=5.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (5.3.1)\n",
+      "Requirement already satisfied: regex!=2019.12.17 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (2020.10.15)\n",
+      "Requirement already satisfied: requests in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (2.24.0)\n",
+      "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.13.3)\n",
+      "Requirement already satisfied: safetensors>=0.3.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.3.1)\n",
+      "Requirement already satisfied: tqdm>=4.27 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (4.65.0)\n",
+      "Requirement already satisfied: docstring-parser>=0.14.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (0.15)\n",
+      "Requirement already satisfied: rich>=11.1.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (13.6.0)\n",
+      "Requirement already satisfied: shtab>=1.5.6 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (1.6.4)\n",
+      "Requirement already satisfied: psutil in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from accelerate->trl) (5.7.2)\n",
+      "Requirement already satisfied: pyarrow>=8.0.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (12.0.1)\n",
+      "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (0.3.7)\n",
+      "Requirement already satisfied: pandas in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (2.0.3)\n",
+      "Requirement already satisfied: xxhash in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (3.3.0)\n",
+      "Requirement already satisfied: multiprocess in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (0.70.15)\n",
+      "Requirement already satisfied: fsspec>=2021.11.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from fsspec[http]>=2021.11.1->datasets->trl) (2023.6.0)\n",
+      "Requirement already satisfied: aiohttp in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (3.8.5)\n",
+      "Requirement already satisfied: attrs>=17.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (23.1.0)\n",
+      "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (3.2.0)\n",
+      "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (6.0.4)\n",
+      "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (4.0.2)\n",
+      "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.9.2)\n",
+      "Requirement already satisfied: frozenlist>=1.1.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.4.0)\n",
+      "Requirement already satisfied: aiosignal>=1.1.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.3.1)\n",
+      "Requirement already satisfied: chardet<4,>=3.0.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (3.0.4)\n",
+      "Requirement already satisfied: idna<3,>=2.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (2.10)\n",
+      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (1.25.11)\n",
+      "Requirement already satisfied: certifi>=2017.4.17 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (2020.6.20)\n",
+      "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n",
+      "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n",
+      "Requirement already satisfied: MarkupSafe>=0.23 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from jinja2->torch>=1.4.0->trl) (1.1.1)\n",
+      "Requirement already satisfied: decorator>=4.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from networkx->torch>=1.4.0->trl) (4.4.2)\n",
+      "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2.8.2)\n",
+      "Requirement already satisfied: pytz>=2020.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2020.1)\n",
+      "Requirement already satisfied: tzdata>=2022.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2023.3)\n",
+      "Requirement already satisfied: mpmath>=0.19 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from sympy->torch>=1.4.0->trl) (1.1.0)\n",
+      "Requirement already satisfied: mdurl~=0.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n",
+      "Requirement already satisfied: six>=1.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from python-dateutil>=2.8.2->pandas->datasets->trl) (1.15.0)\n",
+      "Using cached trl-0.7.4-py3-none-any.whl (133 kB)\n",
+      "\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n",
+      "\u001b[0m\u001b[33mDEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
+      "\u001b[0mInstalling collected packages: trl\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Successfully installed trl-0.7.4\r\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install --user trl"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn.functional as F"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "a = torch.randn(1, 4, requires_grad=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[ 0.2511, -0.2847, -1.3987,  0.9078]], requires_grad=True)"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "a"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "s = torch.argmax(a)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(3)"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "s"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([ 0., 10.,  3.,  0.], requires_grad=True)"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "weights = torch.tensor([0, 10, 3, 0], dtype=torch.float, requires_grad=True) # create a tensor of weights\n",
+    "weights"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "s = torch.multinomial(weights, 1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([2])"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "s"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "RuntimeError",
+     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-22-12626f95b38d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
+     ]
+    }
+   ],
+   "source": [
+    "s.backward()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 226,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from transformers import AutoTokenizer, GPT2LMHeadModel\n",
+    "\n",
+    "\n",
+    "torch.manual_seed(12046)\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 227,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "50257"
+      ]
+     },
+     "execution_count": 227,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "tokenizer.vocab_size"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 231,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "1.027730370438185e+47"
+      ]
+     },
+     "execution_count": 231,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import math\n",
+    "\n",
+    "math.pow(50256, 10)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 55,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
+      "Input length of input_ids is 7, but `max_length` is set to 1. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.\n"
+     ]
+    }
+   ],
+   "source": [
+    "res = model.generate(**ids, max_length=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 56,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[2061,  318,  262, 3139,  286, 2807,   30,  198]])"
+      ]
+     },
+     "execution_count": 56,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "res"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "l = model(**ids).logits"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[[ -37.2172,  -36.8864,  -40.3563,  ...,  -43.4351,  -43.0232,\n",
+       "           -37.0993],\n",
+       "         [ -84.7418,  -82.6453,  -88.2857,  ...,  -89.4501,  -90.0555,\n",
+       "           -84.6770],\n",
+       "         [ -94.2372,  -92.2747,  -95.0347,  ...,  -94.9362,  -97.9820,\n",
+       "           -91.8550],\n",
+       "         ...,\n",
+       "         [ -75.3586,  -73.5543,  -77.7737,  ...,  -80.5165,  -81.9881,\n",
+       "           -76.4165],\n",
+       "         [ -77.2812,  -75.8471,  -80.8869,  ...,  -88.2672,  -86.0258,\n",
+       "           -79.5727],\n",
+       "         [-123.9086, -123.2837, -124.5704,  ..., -135.1090, -134.9653,\n",
+       "          -119.7716]]], grad_fn=<UnsafeViewBackward0>)"
+      ]
+     },
+     "execution_count": 54,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "l"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[[ -37.2172,  -36.8864,  -40.3563,  ...,  -43.4351,  -43.0232,\n",
+       "           -37.0993],\n",
+       "         [ -84.7418,  -82.6453,  -88.2857,  ...,  -89.4501,  -90.0555,\n",
+       "           -84.6770],\n",
+       "         [ -94.2372,  -92.2747,  -95.0347,  ...,  -94.9362,  -97.9820,\n",
+       "           -91.8550],\n",
+       "         ...,\n",
+       "         [ -75.3586,  -73.5543,  -77.7737,  ...,  -80.5165,  -81.9881,\n",
+       "           -76.4165],\n",
+       "         [ -77.2812,  -75.8471,  -80.8869,  ...,  -88.2672,  -86.0258,\n",
+       "           -79.5727],\n",
+       "         [-123.9086, -123.2837, -124.5704,  ..., -135.1090, -134.9653,\n",
+       "          -119.7716]]])"
+      ]
+     },
+     "execution_count": 60,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "with torch.no_grad():\n",
+    "    ll = model(**ids).logits\n",
+    "ll"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 87,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "weights = torch.tensor([0, 10, 3, 0], dtype=torch.float, requires_grad=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 88,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([1])"
+      ]
+     },
+     "execution_count": 88,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "y = torch.multinomial(weights, 1)\n",
+    "y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 89,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "RuntimeError",
+     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-89-ab75bb780f4c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
+     ]
+    }
+   ],
+   "source": [
+    "y.backward()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 91,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "<ipython-input-91-26e670d9452d>:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
+      "  F.softmax(weights)[y]\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor([0.9990], grad_fn=<IndexBackward0>)"
+      ]
+     },
+     "execution_count": 91,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "F.softmax(weights)[y]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 82,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([2])"
+      ]
+     },
+     "execution_count": 82,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 86,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "k = (weights == weights[y]).nonzero()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 51,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "ename": "RuntimeError",
+     "evalue": "only Tensors of floating point and complex dtype can require gradients",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-51-e107420eacad>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m: only Tensors of floating point and complex dtype can require gradients"
+     ]
+    }
+   ],
+   "source": [
+    "ids['input_ids'].requires_grad=True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 49,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Tensor"
+      ]
+     },
+     "execution_count": 49,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 50,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Tensor"
+      ]
+     },
+     "execution_count": 50,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "type(a)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(1.0545, grad_fn=<SumBackward0>)"
+      ]
+     },
+     "execution_count": 57,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "a = torch.rand(1, 4)\n",
+    "a.requires_grad=True\n",
+    "y = torch.sum(a)\n",
+    "y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.sum(a).backward()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(8903)"
+      ]
+     },
+     "execution_count": 37,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "y = torch.sum(ids['input_ids'])\n",
+    "y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "RuntimeError",
+     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-38-ab75bb780f4c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
+     ]
+    }
+   ],
+   "source": [
+    "y.backward()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 94,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.randn(3, 2, requires_grad=True)\n",
+    "y = torch.ones(3, 2)\n",
+    "x\n",
+    "z = torch.where(x > 0, 1.0, 0.0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 95,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[0., 1.],\n",
+       "        [1., 1.],\n",
+       "        [0., 0.]])"
+      ]
+     },
+     "execution_count": 95,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "z"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 103,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "t = (x - 0 > 0.001) * 1 + 0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 100,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[-1.4747,  0.2449],\n",
+       "        [ 0.8045,  0.4111],\n",
+       "        [-0.7769, -0.4431]], requires_grad=True)"
+      ]
+     },
+     "execution_count": 100,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 104,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[0, 1],\n",
+       "        [1, 1],\n",
+       "        [0, 0]])"
+      ]
+     },
+     "execution_count": 104,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "t"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 106,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[False,  True],\n",
+       "        [ True,  True],\n",
+       "        [False, False]])"
+      ]
+     },
+     "execution_count": 106,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "(x - 0 > 0.001)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 108,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[-1.4747,  0.2449],\n",
+       "        [ 0.8045,  0.4111],\n",
+       "        [-0.7769, -0.4431]], requires_grad=True)"
+      ]
+     },
+     "execution_count": 108,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 109,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.8045, grad_fn=<MaxBackward1>)"
+      ]
+     },
+     "execution_count": 109,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.max(x)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 117,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'trl'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-117-8b0f0cf4af74>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtrl\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAutoModelForCausalLMWithValueHea\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'trl'"
+     ]
+    }
+   ],
+   "source": [
+    "from trl import AutoModelForCausalLMWithValueHea"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/tgbaggio/.local/lib/python3.8/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n",
+      "  warnings.warn(\n"
+     ]
+    }
+   ],
+   "source": [
+    "import trl"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from trl import AutoModelForCausalLMWithValueHead"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c8eb15f8baa64dbfa6a41e4349bc1068",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading (…)lve/main/config.json:   0%|          | 0.00/577 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c92409e5bc944693b5578571245093bf",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "model = AutoModelForCausalLMWithValueHead.from_pretrained('lvwerra/gpt2-imdb')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AutoModelForCausalLMWithValueHead(\n",
+       "  (pretrained_model): GPT2LMHeadModel(\n",
+       "    (transformer): GPT2Model(\n",
+       "      (wte): Embedding(50257, 768)\n",
+       "      (wpe): Embedding(1024, 768)\n",
+       "      (drop): Dropout(p=0.1, inplace=False)\n",
+       "      (h): ModuleList(\n",
+       "        (0-11): 12 x GPT2Block(\n",
+       "          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+       "          (attn): GPT2Attention(\n",
+       "            (c_attn): Conv1D()\n",
+       "            (c_proj): Conv1D()\n",
+       "            (attn_dropout): Dropout(p=0.1, inplace=False)\n",
+       "            (resid_dropout): Dropout(p=0.1, inplace=False)\n",
+       "          )\n",
+       "          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+       "          (mlp): GPT2MLP(\n",
+       "            (c_fc): Conv1D()\n",
+       "            (c_proj): Conv1D()\n",
+       "            (act): NewGELUActivation()\n",
+       "            (dropout): Dropout(p=0.1, inplace=False)\n",
+       "          )\n",
+       "        )\n",
+       "      )\n",
+       "      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+       "    )\n",
+       "    (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
+       "  )\n",
+       "  (v_head): ValueHead(\n",
+       "    (dropout): Dropout(p=0.1, inplace=False)\n",
+       "    (summary): Linear(in_features=768, out_features=1, bias=True)\n",
+       "    (flatten): Flatten(start_dim=1, end_dim=-1)\n",
+       "  )\n",
+       ")"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "l = model(**ids)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 7, 50257])"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "l[0].shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([1, 7])"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "l[2].shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AutoTokenizer, pipeline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "a14f54b9e20d4fa2bf97d5d707531fcf",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading (…)lve/main/config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "1397323a065743a19dc8e235b0a49018",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "898a7c9e19d549459bd8a92006ae3f2b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading (…)okenizer_config.json:   0%|          | 0.00/333 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "123feae78d4f4bc182d7405985f33462",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fbf7ebeb54244310bfe2d225e1441db4",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "7170262df42543be877e56798b06a64a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\n",
+      "pip install xformers.\n"
+     ]
+    }
+   ],
+   "source": [
+    "p = pipeline('sentiment-analysis', 'lvwerra/distilbert-imdb')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sent_kwargs = {\"return_all_scores\": True, \"function_to_apply\": \"none\", \"batch_size\": 16}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'input_ids': tensor([[2061,  318,  262, 3139,  286, 2807,   30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question = 'What is the capital of China?'\n",
+    "ids = tokenizer(question, return_tensors=\"pt\")\n",
+    "ids"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pipe_outputs = p(question, **sent_kwargs)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[[{'label': 'NEGATIVE', 'score': 0.28462526202201843},\n",
+       "  {'label': 'POSITIVE', 'score': -0.498414009809494}]]"
+      ]
+     },
+     "execution_count": 31,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pipe_outputs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "score = torch.randn(3)\n",
+    "log = torch.randn(3, 7)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 45,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor(2.3532) torch.Size([7])\n",
+      "tensor(-0.3089) torch.Size([7])\n",
+      "tensor(-1.0061) torch.Size([7])\n"
+     ]
+    }
+   ],
+   "source": [
+    "for s, l in zip(score, log):\n",
+    "    print(s, l.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.randn(4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 47,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([ 0.8624,  0.1586, -1.3384,  0.4741])"
+      ]
+     },
+     "execution_count": 47,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x += 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 49,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([ 1.8624,  1.1586, -0.3384,  1.4741])"
+      ]
+     },
+     "execution_count": 49,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 51,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x[2] += 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 55,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[-0.0498,  0.3148, -0.6248,  0.0660],\n",
+       "        [ 1.1840, -1.0514, -0.8765, -0.0313],\n",
+       "        [ 2.3417, -0.0704, -0.0996,  1.7966]])"
+      ]
+     },
+     "execution_count": 55,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "x = torch.randn(3, 4)\n",
+    "x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 56,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ValueError",
+     "evalue": "step must be greater than zero",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-56-fc433e12835a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mValueError\u001b[0m: step must be greater than zero"
+     ]
+    }
+   ],
+   "source": [
+    "x[::-1]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 79,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "self_config_gamma = 0.8\n",
+    "self_config_lam = 0.9\n",
+    "\n",
+    "def compute_advantages(\n",
+    "    values: torch.FloatTensor,\n",
+    "    rewards: torch.FloatTensor):\n",
+    "    lastgaelam = 0\n",
+    "    advantages_reversed = []\n",
+    "    gen_len = rewards.shape[-1]\n",
+    "    for t in reversed(range(gen_len)):\n",
+    "        nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0\n",
+    "        delta = rewards[:, t] + self_config_gamma * nextvalues - values[:, t]\n",
+    "        lastgaelam = delta + self_config_gamma * self_config_lam * lastgaelam\n",
+    "        advantages_reversed.append(lastgaelam)\n",
+    "    advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)\n",
+    "\n",
+    "    returns = advantages + values\n",
+    "    advantages = advantages.detach()\n",
+    "    return values, advantages, returns\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 80,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "values = torch.tensor([[3., 4., 2.]])\n",
+    "rewards = torch.tensor([[-1., 2., 2.]])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 81,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([[3., 4., 2.]]),\n",
+       " tensor([[-1.0880, -0.4000,  0.0000]]),\n",
+       " tensor([[1.9120, 3.6000, 2.0000]]))"
+      ]
+     },
+     "execution_count": 81,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "compute_advantages(values, rewards)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "1.8800000000000003"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "0.8 * 2 + 0.8 ** 2 * 2 - 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import gym\n",
+    "import pygame\n",
+    "\n",
+    "env = gym.make('MountainCar-v0', render_mode=\"human\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The observation space: Box([-1.2  -0.07], [0.6  0.07], (2,), float32)\n",
+      "The action space: Discrete(3)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Observation and action space \n",
+    "obs_space = env.observation_space\n",
+    "action_space = env.action_space\n",
+    "print(\"The observation space: {}\".format(obs_space))\n",
+    "print(\"The action space: {}\".format(action_space))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The initial observation is (array([-0.46305716,  0.        ], dtype=float32), {})\n",
+      "0\n",
+      "The new observation is [-0.46450874 -0.00145157]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import matplotlib.pyplot as plt \n",
+    "\n",
+    "# reset the environment and see the initial observation\n",
+    "obs = env.reset()\n",
+    "print(\"The initial observation is {}\".format(obs))\n",
+    "\n",
+    "# Sample a random action from the entire action space\n",
+    "random_action = env.action_space.sample()\n",
+    "print(random_action)\n",
+    "\n",
+    "# # Take the action and get the new observation space\n",
+    "new_obs, reward, done, _, info = env.step(random_action)\n",
+    "env.render() \n",
+    "pygame.display.update()\n",
+    "print(\"The new observation is {}\".format(new_obs))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2\n",
+      "The new observation is [-5.2235210e-01 -4.3133463e-04]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Sample a random action from the entire action space\n",
+    "random_action = env.action_space.sample()\n",
+    "print(random_action)\n",
+    "\n",
+    "# # Take the action and get the new observation space\n",
+    "new_obs, reward, done, _, info = env.step(random_action)\n",
+    "print(\"The new observation is {}\".format(new_obs))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "env.close()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<function pygame.event.get>"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import pygame\n",
+    "\n",
+    "pygame.event.get"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "DependencyNotInstalled",
+     "evalue": "Box2D is not installed, run `pip install gymnasium[box2d]`",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/bipedal_walker.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m     \u001b[0;32mimport\u001b[0m \u001b[0mBox2D\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     16\u001b[0m     from Box2D.b2 import (\n",
+      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'Box2D'",
+      "\nThe above exception was the direct cause of the following exception:\n",
+      "\u001b[0;31mDependencyNotInstalled\u001b[0m                    Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-1-c5175ad63d7c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mgymnasium\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"LunarLander-v2\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrender_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"human\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0mobservation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id, max_episode_steps, autoreset, apply_api_compatibility, disable_env_checker, **kwargs)\u001b[0m\n\u001b[1;32m    754\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    755\u001b[0m         \u001b[0;31m# Assume it's a string\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 756\u001b[0;31m         \u001b[0menv_creator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_env_creator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mentry_point\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    757\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    758\u001b[0m     \u001b[0;31m# Determine if to use the rendering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/registration.py\u001b[0m in \u001b[0;36mload_env_creator\u001b[0;34m(name)\u001b[0m\n\u001b[1;32m    543\u001b[0m     \"\"\"\n\u001b[1;32m    544\u001b[0m     \u001b[0mmod_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\":\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 545\u001b[0;31m     \u001b[0mmod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    546\u001b[0m     \u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    547\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/__init__.py\u001b[0m in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m    125\u001b[0m                 \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    126\u001b[0m             \u001b[0mlevel\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0m_bootstrap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_gcd_import\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlevel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpackage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_load_unlocked\u001b[0;34m(spec)\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap_external.py\u001b[0m in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbipedal_walker\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBipedalWalker\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mBipedalWalkerHardcore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcar_racing\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCarRacing\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlunar_lander\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLunarLander\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLunarLanderContinuous\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/bipedal_walker.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     23\u001b[0m     )\n\u001b[1;32m     24\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m     raise DependencyNotInstalled(\n\u001b[0m\u001b[1;32m     26\u001b[0m         \u001b[0;34m\"Box2D is not installed, run `pip install gymnasium[box2d]`\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m     ) from e\n",
+      "\u001b[0;31mDependencyNotInstalled\u001b[0m: Box2D is not installed, run `pip install gymnasium[box2d]`"
+     ]
+    }
+   ],
+   "source": [
+    "import gymnasium as gym\n",
+    "\n",
+    "env = gym.make(\"LunarLander-v2\", render_mode=\"human\")\n",
+    "observation, info = env.reset()\n",
+    "\n",
+    "for _ in range(1000):\n",
+    "    action = env.action_space.sample()  # agent policy that uses the observation and info\n",
+    "    observation, reward, terminated, truncated, info = env.step(action)\n",
+    "\n",
+    "    if terminated or truncated:\n",
+    "        observation, info = env.reset()\n",
+    "\n",
+    "env.close()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<contextlib.ExitStack at 0x7fbffd7b04c0>"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import gymnasium as gym\n",
+    "import math\n",
+    "import random\n",
+    "import matplotlib\n",
+    "import matplotlib.pyplot as plt\n",
+    "env = gym.make(\"CartPole-v1\")\n",
+    "\n",
+    "# set up matplotlib\n",
+    "is_ipython = 'inline' in matplotlib.get_backend()\n",
+    "if is_ipython:\n",
+    "    from IPython import display\n",
+    "\n",
+    "plt.ion()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 52,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import gymnasium as gym\n",
+    "import pygame\n",
+    "\n",
+    "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n",
+    "observation, info = env.reset(seed=42)\n",
+    "#for _ in range(100):\n",
+    "#    action = env.action_space.sample()  # this is where you would insert your policy\n",
+    "#    observation, reward, terminated, truncated, info = env.step(action)\n",
+    "#    if terminated or truncated:\n",
+    "#        break\n",
+    "#env.close()\n",
+    "\n",
+    "#pygame.display.update()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 62,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(array([ 0.20159529,  1.9464185 , -0.22034578, -2.9908078 ], dtype=float32),\n",
+       " 1.0,\n",
+       " True,\n",
+       " False,\n",
+       " {})"
+      ]
+     },
+     "execution_count": 62,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "action = env.action_space.sample()  # this is where you would insert your policy\n",
+    "env.step(1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "env.close()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0.075"
+      ]
+     },
+     "execution_count": 64,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "0.3/4"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 163,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "m = [\n",
+    "    [0.775, 0.075, 0.075, 0.075],\n",
+    "    [0.075, 0.775, 0.075, 0.075],\n",
+    "    [0.075, 0.075, 0.775, 0.075],\n",
+    "    [0.075, 0.075, 0.075, 0.775]\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 164,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "m = torch.tensor(m)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 95,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[1.],\n",
+       "        [1.],\n",
+       "        [1.],\n",
+       "        [1.]])"
+      ]
+     },
+     "execution_count": 95,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "m @ torch.ones(4, 1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 192,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "r = torch.tensor([[2.0, 1.0, -1.0, -2.0]]).T\n",
+    "v = torch.tensor([[200.0, 100.0, -100.0, -200.0]]).T\n",
+    "gamma = 0.9"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 198,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[ 5.4054],\n",
+      "        [ 2.7027],\n",
+      "        [-2.7027],\n",
+      "        [-5.4054]])\n",
+      "tensor([[ 5.4054],\n",
+      "        [ 2.7027],\n",
+      "        [-2.7027],\n",
+      "        [-5.4054]])\n",
+      "tensor([[ 5.4054],\n",
+      "        [ 2.7027],\n",
+      "        [-2.7027],\n",
+      "        [-5.4054]])\n",
+      "tensor([[ 5.4054],\n",
+      "        [ 2.7027],\n",
+      "        [-2.7027],\n",
+      "        [-5.4054]])\n",
+      "tensor([[ 5.4054],\n",
+      "        [ 2.7027],\n",
+      "        [-2.7027],\n",
+      "        [-5.4054]])\n",
+      "tensor([[ 5.4054],\n",
+      "        [ 2.7027],\n",
+      "        [-2.7027],\n",
+      "        [-5.4054]])\n",
+      "tensor([[ 5.4054],\n",
+      "        [ 2.7027],\n",
+      "        [-2.7027],\n",
+      "        [-5.4054]])\n",
+      "tensor([[ 5.4054],\n",
+      "        [ 2.7027],\n",
+      "        [-2.7027],\n",
+      "        [-5.4054]])\n",
+      "tensor([[ 5.4054],\n",
+      "        [ 2.7027],\n",
+      "        [-2.7027],\n",
+      "        [-5.4054]])\n",
+      "tensor([[ 5.4054],\n",
+      "        [ 2.7027],\n",
+      "        [-2.7027],\n",
+      "        [-5.4054]])\n"
+     ]
+    }
+   ],
+   "source": [
+    "for i in range(10):\n",
+    "    print(v)\n",
+    "    v = gamma * m @ v + r"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 142,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "m = [\n",
+    "    [0.5, 0.5, 0.0],\n",
+    "    [0.25, 0.5, 0.25],\n",
+    "    [0, 0.5, 0.5]\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 143,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "m = torch.tensor(m)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 160,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "n = torch.tensor([[0.0, 0.9, 0.1]])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'n' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-2-17b8a0a14bcd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m     \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'n' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "for i in range(10):\n",
+    "    print(n)\n",
+    "    n = n @ m\n",
+    "    \n",
+    "    \n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 51,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch.autograd import Variable"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 161,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 162,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 170,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([[1., 0., 0.]], grad_fn=<AddBackward0>),\n",
+       " tensor([[-0.2796,  0.1884,  0.0913]]))"
+      ]
+     },
+     "execution_count": 170,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "logits = torch.tensor([[1., 1., 1.]], requires_grad=True)\n",
+    "y = F.gumbel_softmax(logits, tau=1, hard=True)\n",
+    "(y * torch.arange(3)).sum().backward()\n",
+    "y, logits.grad"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 148,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[1., 1.]], grad_fn=<AddBackward0>)"
+      ]
+     },
+     "execution_count": 148,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "y_soft = F.gumbel_softmax(logits, tau=1, hard=False)\n",
+    "y_soft.max(-1, keepdim=True)[1] - y_soft.detach() + y_soft"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 117,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[1., 0.]], grad_fn=<AddBackward0>)"
+      ]
+     },
+     "execution_count": 117,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "F.gumbel_softmax(logits, tau=1, hard=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 81,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def sample_gumbel(shape, eps=1e-20):\n",
+    "    U = torch.rand(shape)\n",
+    "    return -Variable(torch.log(-torch.log(U + eps) + eps))\n",
+    "\n",
+    "def gumbel_softmax_sample(logits, temperature):\n",
+    "    y = logits + sample_gumbel(logits.size())\n",
+    "    return F.softmax(y / temperature, dim=-1)\n",
+    "\n",
+    "def gumbel_softmax(logits, temperature):\n",
+    "    \"\"\"\n",
+    "    ST-gumple-softmax\n",
+    "    input: [*, n_class]\n",
+    "    return: flatten --> [*, n_class] an one-hot vector\n",
+    "    \"\"\"\n",
+    "    y = gumbel_softmax_sample(logits, temperature)\n",
+    "    shape = y.size()\n",
+    "    _, ind = y.max(dim=-1)\n",
+    "    y_hard = torch.zeros_like(y).view(-1, shape[-1])\n",
+    "    y_hard.scatter_(1, ind.view(-1, 1), 1)\n",
+    "    y_hard = y_hard.view(*shape)\n",
+    "    y_hard = (y_hard - y).detach() + y\n",
+    "    return y_hard.view(-1,latent_dim*categorical_dim), y, ind"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'latent_dim' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-53-56c036ce50bf>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgumbel_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;32m<ipython-input-52-69711d14ce77>\u001b[0m in \u001b[0;36mgumbel_softmax\u001b[0;34m(logits, temperature)\u001b[0m\n\u001b[1;32m     20\u001b[0m     \u001b[0my_hard\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_hard\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     21\u001b[0m     \u001b[0my_hard\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0my_hard\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0my_hard\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlatent_dim\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mcategorical_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m: name 'latent_dim' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "gumbel_softmax(torch.randn(4, 3), 0.1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 55,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "latent_dim = 10\n",
+    "categorical_dim = 3\n",
+    "x = torch.randn(5, latent_dim, categorical_dim)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 85,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x.requires_grad=True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 86,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "y, k, ind = gumbel_softmax(x, 0.1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 87,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[1, 2, 2, 2, 2, 0, 2, 1, 2, 1],\n",
+       "        [0, 2, 0, 0, 2, 1, 0, 0, 0, 0],\n",
+       "        [0, 0, 2, 1, 1, 1, 1, 0, 0, 0],\n",
+       "        [1, 2, 1, 2, 0, 1, 1, 1, 0, 0],\n",
+       "        [2, 0, 0, 1, 0, 2, 0, 1, 0, 1]])"
+      ]
+     },
+     "execution_count": 87,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "ind"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 89,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.return_types.max(\n",
+       "values=tensor([[0.5745, 1.0000, 0.9985, 1.0000, 1.0000, 0.5934, 1.0000, 0.8247, 1.0000,\n",
+       "         0.9999],\n",
+       "        [0.7558, 1.0000, 0.9999, 1.0000, 1.0000, 0.5294, 0.9994, 0.8727, 0.9999,\n",
+       "         0.9213],\n",
+       "        [1.0000, 1.0000, 1.0000, 1.0000, 0.9942, 1.0000, 0.9996, 1.0000, 1.0000,\n",
+       "         1.0000],\n",
+       "        [0.9993, 1.0000, 1.0000, 1.0000, 0.9998, 0.9998, 1.0000, 0.9661, 0.5564,\n",
+       "         1.0000],\n",
+       "        [1.0000, 1.0000, 0.9999, 1.0000, 0.9785, 0.8706, 1.0000, 1.0000, 1.0000,\n",
+       "         0.5304]], grad_fn=<MaxBackward0>),\n",
+       "indices=tensor([[1, 2, 2, 2, 2, 0, 2, 1, 2, 1],\n",
+       "        [0, 2, 0, 0, 2, 1, 0, 0, 0, 0],\n",
+       "        [0, 0, 2, 1, 1, 1, 1, 0, 0, 0],\n",
+       "        [1, 2, 1, 2, 0, 1, 1, 1, 0, 0],\n",
+       "        [2, 0, 0, 1, 0, 2, 0, 1, 0, 1]]))"
+      ]
+     },
+     "execution_count": 89,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "k.max(dim=-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 94,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([1, 2, 2, 2, 2, 0, 2, 1, 2, 1])"
+      ]
+     },
+     "execution_count": 94,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "k.max(dim=-1)[1][0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 92,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([10])"
+      ]
+     },
+     "execution_count": 92,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.arange(10).shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 171,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[1., 0., 0.]], grad_fn=<AddBackward0>)"
+      ]
+     },
+     "execution_count": 171,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 107,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'nn' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-107-e31dcaa6664b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m: name 'nn' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "l = nn.embedding(3, 4)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 172,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'nn' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-172-23c43f10cdd2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mclass\u001b[0m \u001b[0mRewardModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m         \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'nn' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "class RewardModel(nn.Module):\n",
+    "\n",
+    "    def __init__(self, model):\n",
+    "        super().__init__()\n",
+    "        self.embedding = model\n",
+    "        self.score = nn.Linear(model.embed_dim, 1, bias=False)\n",
+    "\n",
+    "    def forward(self, x, seq_len=None):\n",
+    "        # x:表示文本,形状(B, T, C), seq_len:表示文本长度,形状(B)\n",
+    "        B, T = x.shape\n",
+    "        emb = self.embedding(x).last_hidden_state  # (B, T, C)\n",
+    "        ind = torch.arange(B, device=x.device)\n",
+    "        if seq_len == None:\n",
+    "            seq_len = torch.tensor([T] * B)\n",
+    "        # 获取最后一个词元的特征\n",
+    "        pooled_emb = emb[ind, seq_len - 1]         # (B,    C)\n",
+    "        score = self.score(pooled_emb)             # (B,    1)\n",
+    "        return score\n",
+    "\n",
+    "\n",
+    "r_model = RewardModel(GPT2Model.from_pretrained('gpt2'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 175,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = torch.randn(1, 4, requires_grad=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 176,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(1.7199, grad_fn=<MaxBackward1>)"
+      ]
+     },
+     "execution_count": 176,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.max(x)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 181,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[1]])"
+      ]
+     },
+     "execution_count": 181,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "logits = torch.randn(1, 5, requires_grad=True)\n",
+    "probs = F.softmax(logits, dim=-1)\n",
+    "y = torch.multinomial(probs, num_samples=1)\n",
+    "y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 182,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[1., 0., 0., 0., 0.]], grad_fn=<AddBackward0>)"
+      ]
+     },
+     "execution_count": 182,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "y_one_hot = F.gumbel_softmax(logits, hard=True)\n",
+    "gumbel_y = torch.argmax(y_one_hot, dim=-1, keepdim=True)\n",
+    "y_one_hot"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 183,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[0]])"
+      ]
+     },
+     "execution_count": 183,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "gumbel_y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 189,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def log(t, eps = 1e-20):\n",
+    "    return torch.log(t.clamp(min = eps))\n",
+    "\n",
+    "def gumbel_noise(t):\n",
+    "    noise = torch.zeros_like(t).uniform_(0, 1)\n",
+    "    return -log(-log(noise))\n",
+    "\n",
+    "def gumbel_sample(t, temperature = 1., dim = -1):\n",
+    "    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 190,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([2])"
+      ]
+     },
+     "execution_count": 190,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "gumbel_sample(torch.randn(1, 4, requires_grad=True))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 191,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "B, T, vs = (3, 4, 20)\n",
+    "logits = torch.randn(B, T, vs)\n",
+    "labels = torch.randint(vs, (B, T))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 197,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor(3.9027), tensor(3.9027))"
+      ]
+     },
+     "execution_count": 197,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "l = F.cross_entropy(logits.view(B * T, vs), labels.view(B * T), reduction='none')\n",
+    "l.mean(), F.cross_entropy(logits.view(B * T, vs), labels.view(B * T))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 203,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[-4.3894, -4.6075, -2.8143, -3.5261],\n",
+       "        [-4.3982, -3.5565, -2.3154, -4.5414],\n",
+       "        [-4.1661, -3.0418, -2.9297, -6.5461]])"
+      ]
+     },
+     "execution_count": 203,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "lnP = -F.cross_entropy(logits.transpose(-2, -1), labels, reduction='none')\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 221,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lossi = F.cross_entropy(logits.transpose(-2, -1), labels, reduction='none')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 225,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(3.9027)"
+      ]
+     },
+     "execution_count": 225,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "lossi.mean(-1).mean(-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 220,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(3.9027)"
+      ]
+     },
+     "execution_count": 220,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "(-r * lnP).mean()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 211,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([-15.3374, -14.8116, -16.6836])"
+      ]
+     },
+     "execution_count": 211,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "lnP.sum(-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 213,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[0, 0, 0, 0, 1, 1, 1, 0, 1, 1],\n",
+       "        [0, 0, 1, 0, 1, 1, 0, 0, 1, 0],\n",
+       "        [0, 0, 1, 1, 1, 1, 1, 1, 1, 0],\n",
+       "        [0, 1, 1, 1, 1, 1, 0, 0, 0, 1],\n",
+       "        [1, 0, 1, 1, 1, 1, 1, 1, 0, 0],\n",
+       "        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+       "        [1, 0, 0, 0, 0, 1, 1, 1, 0, 0],\n",
+       "        [0, 1, 1, 0, 0, 0, 0, 1, 1, 0],\n",
+       "        [0, 1, 1, 1, 1, 1, 0, 0, 1, 0],\n",
+       "        [1, 1, 1, 1, 0, 0, 0, 0, 1, 1],\n",
+       "        [0, 1, 0, 1, 1, 1, 1, 1, 0, 1],\n",
+       "        [0, 1, 0, 0, 1, 1, 0, 1, 0, 1],\n",
+       "        [0, 0, 1, 1, 0, 1, 0, 0, 0, 0],\n",
+       "        [0, 0, 1, 1, 1, 0, 0, 0, 1, 0],\n",
+       "        [0, 1, 1, 0, 0, 1, 0, 1, 1, 0],\n",
+       "        [0, 0, 0, 0, 0, 1, 1, 1, 1, 0],\n",
+       "        [1, 0, 0, 1, 0, 0, 1, 0, 1, 1],\n",
+       "        [1, 1, 1, 0, 0, 0, 0, 0, 0, 1],\n",
+       "        [1, 1, 1, 0, 1, 1, 1, 1, 0, 0],\n",
+       "        [0, 0, 0, 1, 1, 1, 1, 1, 0, 1]])"
+      ]
+     },
+     "execution_count": 213,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.randint(2, (20, 10))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "a = torch.tensor([0.5, 0.2, 0.1])\n",
+    "r = torch.tensor([1001., 1002., 1003.])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([500.5000, 200.4000, 100.3000]), tensor(208.2627))"
+      ]
+     },
+     "execution_count": 28,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "a * r, (a * r).std()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([0.0000, 0.2000, 0.2000]), tensor(0.1155))"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "a * (r-1001), (a * (r-1001)).std()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([0., 1., 2.])"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "r-1001"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(208.1666)"
+      ]
+     },
+     "execution_count": 34,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "(1000 * a).std()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.2082)"
+      ]
+     },
+     "execution_count": 35,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "(1 * a).std()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([1001.0000,  400.4000,  200.2000])"
+      ]
+     },
+     "execution_count": 38,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "1000 * a + 1002 * a"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 39,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([1.5000, 0.6000, 0.3000])"
+      ]
+     },
+     "execution_count": 39,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "a + 2 * a"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2, 20])"
+      ]
+     },
+     "execution_count": 64,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "g = torch.normal(1, 1, (2, 20))\n",
+    "g.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 78,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([2, 1])"
+      ]
+     },
+     "execution_count": 78,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "r = torch.normal(1, 1, (2, 1))\n",
+    "r.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 79,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[ 0.9881,  0.7762,  1.3812,  0.7958,  1.0094,  0.1070,  0.1709,  0.3961,\n",
+       "         -0.0118,  0.2604,  0.5242,  0.2400,  0.1989,  0.2737,  1.4879,  0.4091,\n",
+       "          0.7199, -0.2781,  0.7418,  1.2079],\n",
+       "        [ 3.4830,  0.1242,  0.8680,  0.0464,  0.8523,  1.0290, -0.3891,  2.5647,\n",
+       "         -0.1333,  3.3386,  0.1980,  0.2184,  0.4268,  4.9361,  3.0060,  1.3053,\n",
+       "          2.5407,  1.2318,  1.8263,  1.0207]])"
+      ]
+     },
+     "execution_count": 79,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "help(g * r)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 86,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Help on method norm in module torch._tensor:\n",
+      "\n",
+      "norm(p: Union[float, str, NoneType] = 'fro', dim=None, keepdim=False, dtype=None) method of torch.Tensor instance\n",
+      "    See :func:`torch.norm`\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "help(g.norm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 59,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(1.4053)"
+      ]
+     },
+     "execution_count": 59,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "(g * rr).std()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 88,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([39.9750, 37.7648])"
+      ]
+     },
+     "execution_count": 88,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "(g**2).sum(dim=-1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 179,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def norm_(g):\n",
+    "    k = g / g.norm(dim=-1, keepdim=True)\n",
+    "    return k"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(99.9496)"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "num = 10000\n",
+    "grad = torch.normal(0, 1, (num, 100))\n",
+    "g = torch.normal(100, 1, (num, 1))\n",
+    "(grad * g).std()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(1.0004)"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "num = 10000\n",
+    "grad = torch.normal(0, 1, (num, 100))\n",
+    "g = torch.normal(100, 1, (num, 1)) - 100\n",
+    "(grad * g).std()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 261,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.0996)"
+      ]
+     },
+     "execution_count": 261,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "(g * r).mean(-1).std()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 241,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[100.],\n",
+       "        [100.],\n",
+       "        [100.],\n",
+       "        [100.],\n",
+       "        [100.],\n",
+       "        [100.],\n",
+       "        [100.],\n",
+       "        [100.],\n",
+       "        [100.],\n",
+       "        [100.]])"
+      ]
+     },
+     "execution_count": 241,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.normal(100, 0, (10, 1))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

+ 0 - 0
ch12_rl/__init__.py


Fișier diff suprimat deoarece este prea mare
+ 220 - 0
ch12_rl/a2c.ipynb


+ 400 - 0
ch12_rl/intuition_model-Copy1.ipynb

@@ -0,0 +1,400 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7f8c38c68110>"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "from transformers import AutoTokenizer, GPT2LMHeadModel, GPT2Model\n",
+    "\n",
+    "\n",
+    "torch.manual_seed(12046)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = GPT2LMHeadModel.from_pretrained('gpt2')\n",
+    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class RewardModel(nn.Module):\n",
+    "\n",
+    "    def __init__(self, model):\n",
+    "        super().__init__()\n",
+    "        self.embedding = model\n",
+    "        self.score = nn.Linear(model.embed_dim, 1, bias=False)\n",
+    "\n",
+    "    def forward(self, x, seq_len=None):\n",
+    "        # x:表示文本,形状(B, T, vs)或者(B, T), seq_len:表示文本长度,形状(B)\n",
+    "        B = x.shape[0]\n",
+    "        T = x.shape[1]\n",
+    "        emb = self.get_last_hidden_state(x)     # (B, T, C)\n",
+    "        ind = torch.arange(B, device=x.device)\n",
+    "        if seq_len == None:\n",
+    "            seq_len = torch.tensor([T] * B)\n",
+    "        # 获取最后一个词元的特征\n",
+    "        pooled_emb = emb[ind, seq_len - 1]      # (B,    C)\n",
+    "        score = self.score(pooled_emb)          # (B,    1)\n",
+    "        return score\n",
+    "    \n",
+    "    def get_last_hidden_state(self, x):\n",
+    "        if len(x.shape) == 2:\n",
+    "            # x shape = (B, T)\n",
+    "            emb = self.embedding(x).last_hidden_state  # (B, T, C)\n",
+    "        # 为后面使用gumbel_softmax做准备,直接与embedding的模型参数进行计算\n",
+    "        else:\n",
+    "            # x shape = (B, T, vs)\n",
+    "            w = self.embedding.get_input_embeddings().weight  # (vs, C)\n",
+    "            inputs_embeds = x @ w  # (B, T, C)\n",
+    "            emb = self.embedding(inputs_embeds=inputs_embeds).last_hidden_state\n",
+    "        return emb\n",
+    "\n",
+    "r_model = RewardModel(GPT2Model.from_pretrained('gpt2'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0., grad_fn=<MaxBackward1>)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# 验证评分模型计算正确\n",
+    "x = torch.randint(0, tokenizer.vocab_size, (3, 4))\n",
+    "x_hot = F.one_hot(x, num_classes=tokenizer.vocab_size).float()\n",
+    "(r_model(x) - r_model(x_hot)).abs().max()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class RLModel(nn.Module):\n",
+    "    \n",
+    "    def __init__(self, llm, r_model):\n",
+    "        super().__init__()\n",
+    "        self.llm = llm\n",
+    "        self.r_model = r_model\n",
+    "        # 冻结模型\n",
+    "        for param in r_model.parameters():\n",
+    "            param.requires_grad = False\n",
+    "    \n",
+    "    def generate(self, idx, max_new_tokens):\n",
+    "        model = self.llm\n",
+    "        for _ in range(max_new_tokens):\n",
+    "            logits = model(input_ids=idx).logits\n",
+    "            logits = logits[:, -1, :]\n",
+    "            probs = F.softmax(logits, dim=-1)\n",
+    "            # 根据概率,随机生成下一个词元\n",
+    "            idx_next = torch.multinomial(probs, num_samples=1)\n",
+    "            idx = torch.cat((idx, idx_next), dim=1)\n",
+    "        return idx\n",
+    "    \n",
+    "    def forward(self, idx):\n",
+    "        # 为了代码简洁,我们设置产生文本的长度\n",
+    "        ans = self.generate(idx, 20)\n",
+    "        reward = self.r_model(ans)\n",
+    "        return reward"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "inputs = '1 + 2 = 3, 2 + 1 = 3, 1 + 2 ='\n",
+    "ids = tokenizer(inputs, return_tensors=\"pt\")\n",
+    "model = RLModel(llm, r_model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 4, 3 + 1 = 5, 1 + 2 = 6 — Ha ha ha! In us\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 验证generate是正确的\n",
+    "print(tokenizer.decode(model.generate(ids['input_ids'], 20)[0], skip_special_tokens=True))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
+      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 4 without action FARMADAM (same) Wooden child Servant use Intel SOCKS+\n"
+     ]
+    }
+   ],
+   "source": [
+    "res = model.llm.generate(\n",
+    "    input_ids=ids['input_ids'], max_new_tokens=20,\n",
+    "    do_sample=True, top_k=0)[0]\n",
+    "print(tokenizer.decode(res, skip_special_tokens=True))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "RuntimeError",
+     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-9-b7dbb844b37d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;31m# 将报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
+     ]
+    }
+   ],
+   "source": [
+    "loss = -1 * model(ids['input_ids'])\n",
+    "# 将报错\n",
+    "loss.backward()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([ 928.,  926., 1631.,  340., 6175.])\n",
+      "tensor([ 996.,  865., 1616.,  314., 6209.])\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 实验gumbel_softmax\n",
+    "logits = torch.randn(1, 5)\n",
+    "probs = F.softmax(logits, dim=-1)\n",
+    "y = torch.multinomial(probs, num_samples=10000, replacement=True)\n",
+    "print(torch.histogram(y.float(), bins=5).hist)\n",
+    "gumbel_y = torch.argmax(F.gumbel_softmax(logits.repeat(10000, 1), tau=1, hard=True), dim=-1, keepdim=True)\n",
+    "print(torch.histogram(gumbel_y.float(), bins=5).hist)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class RLModelWithGumbel(nn.Module):\n",
+    "    \n",
+    "    def __init__(self, llm, r_model):\n",
+    "        super().__init__()\n",
+    "        self.llm = llm\n",
+    "        self.r_model = r_model\n",
+    "        # 冻结模型\n",
+    "        for param in r_model.parameters():\n",
+    "            param.requires_grad = False\n",
+    "    \n",
+    "    def generate(self, idx, max_new_tokens):\n",
+    "        model = self.llm\n",
+    "        B, T = idx.shape\n",
+    "        ans = None\n",
+    "        for _ in range(max_new_tokens):\n",
+    "            logits = model(input_ids=idx).logits\n",
+    "            logits = logits[:, -1, :]\n",
+    "            # 根据概率,随机生成下一个词元\n",
+    "            idx_next_hot = F.gumbel_softmax(logits, tau=1, hard=True)  # (B, vs)\n",
+    "            idx_next = torch.argmax(idx_next_hot, dim=-1, keepdim=True)\n",
+    "            idx = torch.cat((idx, idx_next.long()), dim=1)\n",
+    "            idx_next_hot = idx_next_hot.unsqueeze(1)      # (B, 1, vs)\n",
+    "            if ans == None:\n",
+    "                ans = idx_next_hot\n",
+    "            else:\n",
+    "                ans = torch.cat((ans, idx_next_hot), dim=1)\n",
+    "        return idx, ans\n",
+    "    \n",
+    "    def forward(self, idx):\n",
+    "        # 为了代码简洁,我们设置产生文本的长度\n",
+    "        _, ans = self.generate(idx, 20)\n",
+    "        reward = self.r_model(ans)\n",
+    "        return reward"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_gumbel = RLModelWithGumbel(llm, r_model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[True, True, True, True, True, True, True, True, True, True, True, True,\n",
+      "         True, True, True, True, True, True, True, True]])\n",
+      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 0, 1 + 1 = 0; extends laugh(cow, decision, discount) fifth person,\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 验证generate正确\n",
+    "idx, ans = model_gumbel.generate(ids['input_ids'], 20)\n",
+    "print(idx[:, ids['input_ids'].shape[1]:] == torch.argmax(ans, dim=-1, keepdim=True).squeeze(-1))\n",
+    "print(tokenizer.decode(idx[0], skip_special_tokens=True))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([[-0.2085]]), tensor([[-0.2085]], grad_fn=<MmBackward0>))"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# 验证评分模型计算正确\n",
+    "model_gumbel.r_model(idx[:, ids['input_ids'].shape[1]:]), model_gumbel.r_model(ans)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[ 2.3994e-06,  4.8380e-06,  3.5403e-06,  ...,  4.4225e-06,\n",
+       "         -1.5709e-06,  4.8997e-06],\n",
+       "        [ 4.4208e-05,  1.3246e-04,  1.4072e-05,  ...,  7.9197e-05,\n",
+       "         -1.4321e-06, -6.9506e-06],\n",
+       "        [ 7.8832e-06,  5.7550e-06, -1.3545e-07,  ...,  5.6032e-06,\n",
+       "         -5.2948e-06,  1.6141e-06],\n",
+       "        ...,\n",
+       "        [ 6.0610e-10,  9.2871e-10,  3.8407e-10,  ...,  1.6127e-09,\n",
+       "         -1.6454e-09, -8.2414e-10],\n",
+       "        [-1.5970e-09,  4.7921e-09,  6.8945e-09,  ...,  7.0852e-09,\n",
+       "         -7.1524e-09, -1.9468e-09],\n",
+       "        [ 3.6735e-04,  2.7833e-04,  3.1601e-05,  ...,  1.5014e-05,\n",
+       "          3.1863e-04, -2.6312e-04]])"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "loss = -1 * model_gumbel(ids['input_ids'])\n",
+    "# 成功运行反向传播算法\n",
+    "loss.backward()\n",
+    "list(model_gumbel.llm.parameters())[0].grad"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

+ 393 - 0
ch12_rl/intuition_model.ipynb

@@ -0,0 +1,393 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fb20dc5f110>"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "from transformers import AutoTokenizer, GPT2LMHeadModel, GPT2Model\n",
+    "\n",
+    "\n",
+    "torch.manual_seed(12046)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = GPT2LMHeadModel.from_pretrained('gpt2')\n",
+    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class RewardModel(nn.Module):\n",
+    "\n",
+    "    def __init__(self, model):\n",
+    "        super().__init__()\n",
+    "        self.embedding = model\n",
+    "        self.score = nn.Linear(model.embed_dim, 1, bias=False)\n",
+    "\n",
+    "    def forward(self, x, seq_len=None):\n",
+    "        # x:表示文本,形状(B, T, vs)或者(B, T), seq_len:表示文本长度,形状(B)\n",
+    "        B = x.shape[0]\n",
+    "        T = x.shape[1]\n",
+    "        emb = self.get_last_hidden_state(x)     # (B, T, C)\n",
+    "        ind = torch.arange(B, device=x.device)\n",
+    "        if seq_len == None:\n",
+    "            seq_len = torch.tensor([T] * B)\n",
+    "        # 获取最后一个词元的特征\n",
+    "        pooled_emb = emb[ind, seq_len - 1]      # (B,    C)\n",
+    "        score = self.score(pooled_emb)          # (B,    1)\n",
+    "        return score\n",
+    "    \n",
+    "    def get_last_hidden_state(self, x):\n",
+    "        if len(x.shape) == 2:\n",
+    "            # x shape = (B, T)\n",
+    "            emb = self.embedding(x).last_hidden_state  # (B, T, C)\n",
+    "        # 为后面使用gumbel_softmax做准备,直接与embedding的模型参数进行计算\n",
+    "        else:\n",
+    "            # x shape = (B, T, vs)\n",
+    "            w = self.embedding.get_input_embeddings().weight  # (vs, C)\n",
+    "            inputs_embeds = x @ w  # (B, T, C)\n",
+    "            emb = self.embedding(inputs_embeds=inputs_embeds).last_hidden_state\n",
+    "        return emb\n",
+    "\n",
+    "r_model = RewardModel(GPT2Model.from_pretrained('gpt2'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0., grad_fn=<MaxBackward1>)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# 验证评分模型计算正确\n",
+    "x = torch.randint(0, tokenizer.vocab_size, (3, 4))\n",
+    "x_hot = F.one_hot(x, num_classes=tokenizer.vocab_size).float()\n",
+    "(r_model(x) - r_model(x_hot)).abs().max()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class RLModel(nn.Module):\n",
+    "    \n",
+    "    def __init__(self, llm, r_model):\n",
+    "        super().__init__()\n",
+    "        self.llm = llm\n",
+    "        self.r_model = r_model\n",
+    "        # 冻结模型\n",
+    "        for param in r_model.parameters():\n",
+    "            param.requires_grad = False\n",
+    "    \n",
+    "    def generate(self, idx, max_new_tokens):\n",
+    "        model = self.llm\n",
+    "        for _ in range(max_new_tokens):\n",
+    "            logits = model(input_ids=idx).logits\n",
+    "            logits = logits[:, -1, :]\n",
+    "            probs = F.softmax(logits, dim=-1)\n",
+    "            # 根据概率,随机生成下一个词元\n",
+    "            idx_next = torch.multinomial(probs, num_samples=1)\n",
+    "            idx = torch.cat((idx, idx_next), dim=1)\n",
+    "        return idx\n",
+    "    \n",
+    "    def forward(self, idx):\n",
+    "        # 为了代码简洁,我们设置产生文本的长度\n",
+    "        ans = self.generate(idx, 20)\n",
+    "        reward = self.r_model(ans)\n",
+    "        return reward"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "inputs = '1 + 2 = 3, 2 + 1 = 3, 1 + 2 ='\n",
+    "ids = tokenizer(inputs, return_tensors=\"pt\")\n",
+    "model = RLModel(llm, r_model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 4, 3 + 1 = 5, 1 + 2 = 6 — Ha ha ha! In us\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 验证generate是正确的\n",
+    "print(tokenizer.decode(model.generate(ids['input_ids'], 20)[0], skip_special_tokens=True))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
+      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 4 without action FARMADAM (same) Wooden child Servant use Intel SOCKS+\n"
+     ]
+    }
+   ],
+   "source": [
+    "res = model.llm.generate(\n",
+    "    input_ids=ids['input_ids'], max_new_tokens=20,\n",
+    "    do_sample=True, top_k=0)[0]\n",
+    "print(tokenizer.decode(res, skip_special_tokens=True))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "RuntimeError",
+     "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-9-b7dbb844b37d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;31m# 将报错\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    485\u001b[0m                 \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0;31m         torch.autograd.backward(\n\u001b[0m\u001b[1;32m    488\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m         )\n",
+      "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    198\u001b[0m     \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    199\u001b[0m     \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m    201\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    202\u001b[0m         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass\n",
+      "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
+     ]
+    }
+   ],
+   "source": [
+    "loss = -1 * model(ids['input_ids'])\n",
+    "# 将报错\n",
+    "loss.backward()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([ 928.,  926., 1631.,  340., 6175.])\n",
+      "tensor([ 996.,  865., 1616.,  314., 6209.])\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 实验gumbel_softmax\n",
+    "logits = torch.randn(1, 5)\n",
+    "probs = F.softmax(logits, dim=-1)\n",
+    "y = torch.multinomial(probs, num_samples=10000, replacement=True)\n",
+    "print(torch.histogram(y.float(), bins=5).hist)\n",
+    "gumbel_y = torch.argmax(F.gumbel_softmax(logits.repeat(10000, 1), tau=1, hard=True), dim=-1, keepdim=True)\n",
+    "print(torch.histogram(gumbel_y.float(), bins=5).hist)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class RLModelWithGumbel(nn.Module):\n",
+    "    \n",
+    "    def __init__(self, llm, r_model):\n",
+    "        super().__init__()\n",
+    "        self.llm = llm\n",
+    "        self.r_model = r_model\n",
+    "        # 冻结模型\n",
+    "        for param in r_model.parameters():\n",
+    "            param.requires_grad = False\n",
+    "    \n",
+    "    def generate(self, idx, max_new_tokens):\n",
+    "        model = self.llm\n",
+    "        B, T = idx.shape\n",
+    "        ans = None\n",
+    "        for _ in range(max_new_tokens):\n",
+    "            logits = model(input_ids=idx).logits\n",
+    "            logits = logits[:, -1, :]\n",
+    "            # 根据概率,随机生成下一个词元\n",
+    "            idx_next_hot = F.gumbel_softmax(logits, tau=1, hard=True)  # (B, vs)\n",
+    "            idx_next = torch.argmax(idx_next_hot, dim=-1, keepdim=True)\n",
+    "            idx = torch.cat((idx, idx_next.long()), dim=1)\n",
+    "            idx_next_hot = idx_next_hot.unsqueeze(1)      # (B, 1, vs)\n",
+    "            if ans == None:\n",
+    "                ans = idx_next_hot\n",
+    "            else:\n",
+    "                ans = torch.cat((ans, idx_next_hot), dim=1)\n",
+    "        return idx, ans\n",
+    "    \n",
+    "    def forward(self, idx):\n",
+    "        # 为了代码简洁,我们设置产生文本的长度\n",
+    "        _, ans = self.generate(idx, 20)\n",
+    "        reward = self.r_model(ans)\n",
+    "        return reward"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_gumbel = RLModelWithGumbel(llm, r_model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tensor([[True, True, True, True, True, True, True, True, True, True, True, True,\n",
+      "         True, True, True, True, True, True, True, True]])\n",
+      "1 + 2 = 3, 2 + 1 = 3, 1 + 2 = 0, 1 + 1 = 0; extends laugh(cow, decision, discount) fifth person,\n"
+     ]
+    }
+   ],
+   "source": [
+    "# 验证generate正确\n",
+    "idx, ans = model_gumbel.generate(ids['input_ids'], 20)\n",
+    "print(idx[:, ids['input_ids'].shape[1]:] == torch.argmax(ans, dim=-1, keepdim=True).squeeze(-1))\n",
+    "print(tokenizer.decode(idx[0], skip_special_tokens=True))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([[-0.2085]]), tensor([[-0.2085]], grad_fn=<MmBackward0>))"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# 验证评分模型计算正确\n",
+    "model_gumbel.r_model(idx[:, ids['input_ids'].shape[1]:]), model_gumbel.r_model(ans)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[ 2.3994e-06,  4.8380e-06,  3.5403e-06,  ...,  4.4225e-06,\n",
+       "         -1.5709e-06,  4.8997e-06],\n",
+       "        [ 4.4208e-05,  1.3246e-04,  1.4072e-05,  ...,  7.9197e-05,\n",
+       "         -1.4321e-06, -6.9506e-06],\n",
+       "        [ 7.8832e-06,  5.7550e-06, -1.3545e-07,  ...,  5.6032e-06,\n",
+       "         -5.2948e-06,  1.6141e-06],\n",
+       "        ...,\n",
+       "        [ 6.0610e-10,  9.2871e-10,  3.8407e-10,  ...,  1.6127e-09,\n",
+       "         -1.6454e-09, -8.2414e-10],\n",
+       "        [-1.5970e-09,  4.7921e-09,  6.8945e-09,  ...,  7.0852e-09,\n",
+       "         -7.1524e-09, -1.9468e-09],\n",
+       "        [ 3.6735e-04,  2.7833e-04,  3.1601e-05,  ...,  1.5014e-05,\n",
+       "          3.1863e-04, -2.6312e-04]])"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "loss = -1 * model_gumbel(ids['input_ids'])\n",
+    "# 成功运行反向传播算法\n",
+    "loss.backward()\n",
+    "list(model_gumbel.llm.parameters())[0].grad"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

+ 161 - 0
ch12_rl/llm_ppo.ipynb

@@ -0,0 +1,161 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<torch._C.Generator at 0x7fc57cc63110>"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel, GPT2Model\n",
+    "\n",
+    "\n",
+    "torch.manual_seed(12046)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learning_rate = 6e-4\n",
+    "sequence_len = 1024\n",
+    "batch_size = 8\n",
+    "gra_acc_steps = 8 * 2\n",
+    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+    "eval_iters = 64 * 2\n",
+    "eval_interval = 50"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = GPT2LMHeadModel.from_pretrained('gpt2')\n",
+    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "GPT2LMHeadModel(\n",
+       "  (transformer): GPT2Model(\n",
+       "    (wte): Embedding(50257, 768)\n",
+       "    (wpe): Embedding(1024, 768)\n",
+       "    (drop): Dropout(p=0.1, inplace=False)\n",
+       "    (h): ModuleList(\n",
+       "      (0-11): 12 x GPT2Block(\n",
+       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+       "        (attn): GPT2Attention(\n",
+       "          (c_attn): Conv1D()\n",
+       "          (c_proj): Conv1D()\n",
+       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
+       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
+       "        )\n",
+       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+       "        (mlp): GPT2MLP(\n",
+       "          (c_fc): Conv1D()\n",
+       "          (c_proj): Conv1D()\n",
+       "          (act): NewGELUActivation()\n",
+       "          (dropout): Dropout(p=0.1, inplace=False)\n",
+       "        )\n",
+       "      )\n",
+       "    )\n",
+       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
+       "  )\n",
+       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
+       ")"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "class RewardModel(nn.Module):\n",
+    "\n",
+    "    def __init__(self, model):\n",
+    "        super().__init__()\n",
+    "        self.embedding = model\n",
+    "        self.score = nn.Linear(model.embed_dim, 1, bias=False)\n",
+    "\n",
+    "    def forward(self, x, seq_len=None):\n",
+    "        # x:表示文本,形状(B, T, vs)或者(B, T), seq_len:表示文本长度,形状(B)\n",
+    "        B = x.shape[0]\n",
+    "        T = x.shape[1]\n",
+    "        emb = self.get_last_hidden_state(x)     # (B, T, C)\n",
+    "        ind = torch.arange(B, device=x.device)\n",
+    "        if seq_len == None:\n",
+    "            seq_len = torch.tensor([T] * B)\n",
+    "        # 获取最后一个词元的特征\n",
+    "        pooled_emb = emb[ind, seq_len - 1]      # (B,    C)\n",
+    "        score = self.score(pooled_emb)          # (B,    1)\n",
+    "        return score\n",
+    "    \n",
+    "    def get_last_hidden_state(self, x):\n",
+    "        if len(x.shape) == 2:\n",
+    "            # x shape = (B, T)\n",
+    "            emb = self.embedding(x).last_hidden_state  # (B, T, C)\n",
+    "        # 为后面使用gumbel_softmax做准备,直接与embedding的模型参数进行计算\n",
+    "        else:\n",
+    "            # x shape = (B, T, vs)\n",
+    "            w = self.embedding.get_input_embeddings().weight  # (vs, C)\n",
+    "            inputs_embeds = x @ w  # (B, T, C)\n",
+    "            emb = self.embedding(inputs_embeds=inputs_embeds).last_hidden_state\n",
+    "        return emb\n",
+    "\n",
+    "r_model = RewardModel(GPT2Model.from_pretrained('gpt2'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

Fișier diff suprimat deoarece este prea mare
+ 156 - 0
ch12_rl/policy_learning.ipynb


+ 69 - 0
ch12_rl/utils.py

@@ -0,0 +1,69 @@
+# -*- coding: UTF-8 -*-
+"""
+此脚本用于定义游戏以及相应的可视化工具
+"""
+
+
+import matplotlib.pyplot as plt
+import torch
+import pandas as pd
+
+
+class Lottery:
+    
+    def __init__(self):
+        self.params = {
+            'w': (1, 1),
+            'l': (-1, 1)
+        }
+    
+    def reset(self):
+        self.state = 'w' if torch.randn(1).item() > 0 else 'l'
+        return self.state
+        
+    def step(self, action):
+        # 如果状态是t,则终止游戏
+        if self.state == 't':
+            return self.state, 0
+        # 1表示抽奖; 0表示终止
+        center, std = self.params[self.state]
+        if action == 0:
+            self.state = 't'
+            return 't', 0
+        else:
+            reward = torch.normal(center, std, (1,)).item()
+        # 有10%的概率终止游戏
+        if torch.rand(1).item() < 0.01:
+            self.state = 't'
+        return self.state, reward
+
+
+def plot_values(v):
+    # 为在Matplotlib中显示中文,设置特殊字体
+    plt.rcParams["font.sans-serif"] = ["SimHei"]
+    # 正确显示负号
+    plt.rcParams['axes.unicode_minus'] = False
+    plt.rcParams.update({'font.size': 13})
+    # 创建一个图形框
+    fig = plt.figure(figsize=(6, 6), dpi=100)
+    v = pd.DataFrame(v)
+    for k in v:
+        v[k].plot(label=k, legend=True)
+    legend = plt.legend(shadow=True, loc="best", fontsize=20)
+    plt.yticks(range(-10, 11, 4))
+    return fig
+
+
+def plot_action_probs(v):
+    # 为在Matplotlib中显示中文,设置特殊字体
+    plt.rcParams["font.sans-serif"] = ["SimHei"]
+    # 正确显示负号
+    plt.rcParams['axes.unicode_minus'] = False
+    plt.rcParams.update({'font.size': 13})
+    # 创建一个图形框
+    fig = plt.figure(figsize=(6, 6), dpi=100)
+    v = pd.DataFrame(v)
+    for k in v:
+        v[k].apply(lambda x: x[1]).plot(label=k, legend=True)
+    legend = plt.legend(shadow=True, loc="best", fontsize=20)
+    return fig

Fișier diff suprimat deoarece este prea mare
+ 105 - 0
ch12_rl/value_learning.ipynb


Unele fișiere nu au fost afișate deoarece prea multe fișiere au fost modificate în acest diff