{ "cells": [ { "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n", "\u001b[0mCollecting trl\n", " Using cached trl-0.7.4-py3-none-any.whl.metadata (10 kB)\n", "Requirement already satisfied: torch>=1.4.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (2.0.1)\n", "Requirement already satisfied: transformers>=4.18.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (4.31.0)\n", "Requirement already satisfied: numpy>=1.18.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (1.24.4)\n", "Requirement already satisfied: accelerate in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (0.21.0)\n", "Requirement already satisfied: datasets in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (2.14.2)\n", "Requirement already satisfied: tyro>=0.5.11 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (0.5.12)\n", "Requirement already satisfied: filelock in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (3.0.12)\n", "Requirement already satisfied: typing-extensions in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (4.8.0)\n", "Requirement already satisfied: sympy in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (1.6.2)\n", "Requirement already satisfied: networkx in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (2.5)\n", "Requirement already satisfied: jinja2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (2.11.2)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.16.4)\n", "Requirement already satisfied: packaging>=20.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (23.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (5.3.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (2020.10.15)\n", "Requirement already satisfied: requests in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (2.24.0)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.13.3)\n", "Requirement already satisfied: safetensors>=0.3.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.3.1)\n", "Requirement already satisfied: tqdm>=4.27 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (4.65.0)\n", "Requirement already satisfied: docstring-parser>=0.14.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (0.15)\n", "Requirement already satisfied: rich>=11.1.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (13.6.0)\n", "Requirement already satisfied: shtab>=1.5.6 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) 
(1.6.4)\n", "Requirement already satisfied: psutil in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from accelerate->trl) (5.7.2)\n", "Requirement already satisfied: pyarrow>=8.0.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (12.0.1)\n", "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (0.3.7)\n", "Requirement already satisfied: pandas in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (2.0.3)\n", "Requirement already satisfied: xxhash in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (3.3.0)\n", "Requirement already satisfied: multiprocess in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (0.70.15)\n", "Requirement already satisfied: fsspec>=2021.11.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from fsspec[http]>=2021.11.1->datasets->trl) (2023.6.0)\n", "Requirement already satisfied: aiohttp in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (3.8.5)\n", "Requirement already satisfied: attrs>=17.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (23.1.0)\n", "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (3.2.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (6.0.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (4.0.2)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.9.2)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.3.1)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (3.0.4)\n", "Requirement already satisfied: idna<3,>=2.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (2.10)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (1.25.11)\n", "Requirement already satisfied: certifi>=2017.4.17 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (2020.6.20)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n", "Requirement already satisfied: MarkupSafe>=0.23 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from jinja2->torch>=1.4.0->trl) (1.1.1)\n", "Requirement already satisfied: decorator>=4.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from 
networkx->torch>=1.4.0->trl) (4.4.2)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2020.1)\n", "Requirement already satisfied: tzdata>=2022.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2023.3)\n", "Requirement already satisfied: mpmath>=0.19 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from sympy->torch>=1.4.0->trl) (1.1.0)\n", "Requirement already satisfied: mdurl~=0.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n", "Requirement already satisfied: six>=1.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from python-dateutil>=2.8.2->pandas->datasets->trl) (1.15.0)\n", "Using cached trl-0.7.4-py3-none-any.whl (133 kB)\n", "\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mDEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n", "\u001b[0mInstalling collected packages: trl\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Successfully installed trl-0.7.4\r\n" ] } ], "source": [ "!pip install --user trl" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "a = torch.randn(1, 4, requires_grad=True)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.2511, -0.2847, -1.3987, 0.9078]], requires_grad=True)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "s = torch.argmax(a)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(3)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "s" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 0., 10., 3., 0.], requires_grad=True)" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "weights = torch.tensor([0, 10, 3, 0], dtype=torch.float, requires_grad=True) # create a tensor of weights\n", "weights" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "s = torch.multinomial(weights, 1)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([2])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "s" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "element 0 of tensors does not require grad and does not 
have a grad_fn", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n", "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn" ] } ], "source": [ "s.backward()" ] }, { "cell_type": "code", "execution_count": 226, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import AutoTokenizer, GPT2LMHeadModel\n", "\n", "\n", "torch.manual_seed(12046)\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")" ] }, { "cell_type": "code", "execution_count": 227, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "50257" ] }, "execution_count": 227, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.vocab_size" ] }, { "cell_type": "code", "execution_count": 231, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.027730370438185e+47" ] }, "execution_count": 231, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import math\n", "\n", "math.pow(50256, 10)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": 
"stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", "Input length of input_ids is 7, but `max_length` is set to 1. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.\n" ] } ], "source": [ "res = model.generate(**ids, max_length=1)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[2061, 318, 262, 3139, 286, 2807, 30, 198]])" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "l = model(**ids).logits" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[ -37.2172, -36.8864, -40.3563, ..., -43.4351, -43.0232,\n", " -37.0993],\n", " [ -84.7418, -82.6453, -88.2857, ..., -89.4501, -90.0555,\n", " -84.6770],\n", " [ -94.2372, -92.2747, -95.0347, ..., -94.9362, -97.9820,\n", " -91.8550],\n", " ...,\n", " [ -75.3586, -73.5543, -77.7737, ..., -80.5165, -81.9881,\n", " -76.4165],\n", " [ -77.2812, -75.8471, -80.8869, ..., -88.2672, -86.0258,\n", " -79.5727],\n", " [-123.9086, -123.2837, -124.5704, ..., -135.1090, -134.9653,\n", " -119.7716]]], grad_fn=)" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[[ -37.2172, -36.8864, -40.3563, ..., -43.4351, -43.0232,\n", " -37.0993],\n", " [ -84.7418, -82.6453, -88.2857, ..., -89.4501, -90.0555,\n", " -84.6770],\n", " [ -94.2372, -92.2747, -95.0347, ..., -94.9362, -97.9820,\n", " -91.8550],\n", " ...,\n", " [ -75.3586, -73.5543, -77.7737, ..., -80.5165, -81.9881,\n", " -76.4165],\n", " [ -77.2812, -75.8471, -80.8869, ..., -88.2672, -86.0258,\n", " -79.5727],\n", " [-123.9086, -123.2837, -124.5704, ..., -135.1090, -134.9653,\n", " -119.7716]]])" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with torch.no_grad():\n", " ll = model(**ids).logits\n", "ll" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [], "source": [ "weights = torch.tensor([0, 10, 3, 0], dtype=torch.float, requires_grad=True)" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1])" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = torch.multinomial(weights, 1)\n", "y" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "element 0 of tensors does not require grad and does not have a grad_fn", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m 
\u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n", "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn" ] } ], "source": [ "y.backward()" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ ":1: UserWarning: Implicit dimension choice for softmax has been deprecated. 
Change the call to include dim=X as an argument.\n", " F.softmax(weights)[y]\n" ] }, { "data": { "text/plain": [ "tensor([0.9990], grad_fn=)" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "F.softmax(weights)[y]" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([2])" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [], "source": [ "k = (weights == weights[y]).nonzero()" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "scrolled": true }, "outputs": [ { "ename": "RuntimeError", "evalue": "only Tensors of floating point and complex dtype can require gradients", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m: only Tensors of floating point and complex dtype can require gradients" ] } ], "source": [ "ids['input_ids'].requires_grad=True" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Tensor" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Tensor" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(a)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1.0545, grad_fn=)" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = torch.rand(1, 4)\n", "a.requires_grad=True\n", "y = torch.sum(a)\n", "y" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "torch.sum(a).backward()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(8903)" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = torch.sum(ids['input_ids'])\n", "y" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "element 0 of tensors does not require grad and does not have a grad_fn", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m 
\u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n", "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn" ] } ], "source": [ "y.backward()" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(3, 2, requires_grad=True)\n", "y = torch.ones(3, 2)\n", "x\n", "z = torch.where(x > 0, 1.0, 0.0)" ] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0., 1.],\n", " [1., 1.],\n", " [0., 0.]])" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z" ] }, { "cell_type": "code", "execution_count": 103, "metadata": {}, "outputs": [], "source": [ "t = (x - 0 > 0.001) * 1 + 0" ] }, { "cell_type": "code", "execution_count": 100, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.4747, 0.2449],\n", " [ 0.8045, 0.4111],\n", " [-0.7769, -0.4431]], requires_grad=True)" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x" ] }, { "cell_type": "code", "execution_count": 104, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0, 1],\n", " [1, 1],\n", " [0, 0]])" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t" ] }, { "cell_type": "code", "execution_count": 106, "metadata": {}, 
"outputs": [ { "data": { "text/plain": [ "tensor([[False, True],\n", " [ True, True],\n", " [False, False]])" ] }, "execution_count": 106, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(x - 0 > 0.001)" ] }, { "cell_type": "code", "execution_count": 108, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.4747, 0.2449],\n", " [ 0.8045, 0.4111],\n", " [-0.7769, -0.4431]], requires_grad=True)" ] }, "execution_count": 108, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x" ] }, { "cell_type": "code", "execution_count": 109, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.8045, grad_fn=)" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.max(x)" ] }, { "cell_type": "code", "execution_count": 117, "metadata": {}, "outputs": [ { "ename": "ModuleNotFoundError", "evalue": "No module named 'trl'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtrl\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAutoModelForCausalLMWithValueHea\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'trl'" ] } ], "source": [ "from trl import AutoModelForCausalLMWithValueHea" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/tgbaggio/.local/lib/python3.8/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n", " warnings.warn(\n" ] } ], "source": [ "import trl" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from trl import AutoModelForCausalLMWithValueHead" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c8eb15f8baa64dbfa6a41e4349bc1068", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading (…)lve/main/config.json: 0%| | 0.00/577 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mValueError\u001b[0m: step must be greater than zero" ] } ], "source": [ "x[::-1]" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "self_config_gamma = 0.8\n", "self_config_lam = 0.9\n", "\n", "def compute_advantages(\n", " values: torch.FloatTensor,\n", " rewards: torch.FloatTensor):\n", " lastgaelam = 0\n", " advantages_reversed = []\n", " gen_len = rewards.shape[-1]\n", " for t in reversed(range(gen_len)):\n", " nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0\n", " delta = rewards[:, t] + self_config_gamma * nextvalues - values[:, t]\n", " lastgaelam = delta + self_config_gamma * self_config_lam * lastgaelam\n", " advantages_reversed.append(lastgaelam)\n", " advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)\n", "\n", " returns = advantages + values\n", " 
advantages = advantages.detach()\n", " return values, advantages, returns\n" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [], "source": [ "values = torch.tensor([[3., 4., 2.]])\n", "rewards = torch.tensor([[-1., 2., 2.]])" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[3., 4., 2.]]),\n", " tensor([[-1.0880, -0.4000, 0.0000]]),\n", " tensor([[1.9120, 3.6000, 2.0000]]))" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "compute_advantages(values, rewards)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.8800000000000003" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "0.8 * 2 + 0.8 ** 2 * 2 - 1" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import gym\n", "import pygame\n", "\n", "env = gym.make('MountainCar-v0', render_mode=\"human\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The observation space: Box([-1.2 -0.07], [0.6 0.07], (2,), float32)\n", "The action space: Discrete(3)\n" ] } ], "source": [ "# Observation and action space \n", "obs_space = env.observation_space\n", "action_space = env.action_space\n", "print(\"The observation space: {}\".format(obs_space))\n", "print(\"The action space: {}\".format(action_space))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The initial observation is (array([-0.46305716, 0. ], dtype=float32), {})\n", "0\n", "The new observation is [-0.46450874 -0.00145157]\n" ] } ], "source": [ "import matplotlib.pyplot as plt \n", "\n", "# reset the environment and see the initial observation\n", "obs = env.reset()\n", "print(\"The initial observation is {}\".format(obs))\n", "\n", "# Sample a random action from the entire action space\n", "random_action = env.action_space.sample()\n", "print(random_action)\n", "\n", "# # Take the action and get the new observation space\n", "new_obs, reward, done, _, info = env.step(random_action)\n", "env.render() \n", "pygame.display.update()\n", "print(\"The new observation is {}\".format(new_obs))" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2\n", "The new observation is [-5.2235210e-01 -4.3133463e-04]\n" ] } ], "source": [ "# Sample a random action from the entire action space\n", "random_action = env.action_space.sample()\n", "print(random_action)\n", "\n", "# # Take the action and get the new observation space\n", "new_obs, reward, done, _, info = env.step(random_action)\n", "print(\"The new observation is {}\".format(new_obs))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "env.close()\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pygame\n", "\n", "pygame.event.get" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "ename": "DependencyNotInstalled", "evalue": "Box2D is not installed, run `pip install gymnasium[box2d]`", "output_type": "error", "traceback": [ 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/bipedal_walker.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mBox2D\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m from Box2D.b2 import (\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'Box2D'", "\nThe above exception was the direct cause of the following exception:\n", "\u001b[0;31mDependencyNotInstalled\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mgymnasium\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"LunarLander-v2\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrender_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"human\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mobservation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id, max_episode_steps, autoreset, apply_api_compatibility, disable_env_checker, **kwargs)\u001b[0m\n\u001b[1;32m 754\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 755\u001b[0m \u001b[0;31m# Assume it's a string\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 756\u001b[0;31m \u001b[0menv_creator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_env_creator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mentry_point\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 757\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 758\u001b[0m \u001b[0;31m# Determine if to use the rendering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/registration.py\u001b[0m in \u001b[0;36mload_env_creator\u001b[0;34m(name)\u001b[0m\n\u001b[1;32m 543\u001b[0m \"\"\"\n\u001b[1;32m 544\u001b[0m \u001b[0mmod_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\":\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 545\u001b[0;31m \u001b[0mmod\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 546\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 547\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/__init__.py\u001b[0m in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0mlevel\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_bootstrap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_gcd_import\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlevel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpackage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_load_unlocked\u001b[0;34m(spec)\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap_external.py\u001b[0m in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/__init__.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbipedal_walker\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBipedalWalker\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mBipedalWalkerHardcore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcar_racing\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCarRacing\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlunar_lander\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLunarLander\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLunarLanderContinuous\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/bipedal_walker.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m )\n\u001b[1;32m 24\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m raise DependencyNotInstalled(\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0;34m\"Box2D is not installed, run `pip install gymnasium[box2d]`\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m ) from e\n", "\u001b[0;31mDependencyNotInstalled\u001b[0m: Box2D is not installed, run `pip install gymnasium[box2d]`" ] } ], "source": [ "import gymnasium as gym\n", "\n", "env = gym.make(\"LunarLander-v2\", render_mode=\"human\")\n", "observation, info = env.reset()\n", "\n", "for _ in range(1000):\n", " action = env.action_space.sample() # agent policy that uses the observation and info\n", " observation, reward, terminated, truncated, info = env.step(action)\n", "\n", " if terminated or truncated:\n", " observation, info = env.reset()\n", "\n", "env.close()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import gymnasium as gym\n", "import math\n", "import random\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "env = gym.make(\"CartPole-v1\")\n", "\n", "# set up matplotlib\n", "is_ipython = 'inline' in matplotlib.get_backend()\n", "if is_ipython:\n", " from IPython import display\n", "\n", "plt.ion()" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "import gymnasium as gym\n", "import pygame\n", "\n", "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n", "observation, info = env.reset(seed=42)\n", "#for _ in range(100):\n", "# action = env.action_space.sample() # this is where you would insert your policy\n", "# observation, reward, terminated, truncated, info = env.step(action)\n", "# if terminated or truncated:\n", "# break\n", "#env.close()\n", "\n", "#pygame.display.update()" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([ 0.20159529, 1.9464185 , -0.22034578, -2.9908078 ], dtype=float32),\n", " 1.0,\n", " True,\n", " False,\n", " {})" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "action = env.action_space.sample() # this is where you would insert your policy\n", "env.step(1)" ] }, { "cell_type": "code", "execution_count": 
63, "metadata": {}, "outputs": [], "source": [ "env.close()" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.075" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "0.3/4" ] }, { "cell_type": "code", "execution_count": 163, "metadata": {}, "outputs": [], "source": [ "m = [\n", " [0.775, 0.075, 0.075, 0.075],\n", " [0.075, 0.775, 0.075, 0.075],\n", " [0.075, 0.075, 0.775, 0.075],\n", " [0.075, 0.075, 0.075, 0.775]\n", "]" ] }, { "cell_type": "code", "execution_count": 164, "metadata": {}, "outputs": [], "source": [ "import torch\n", "m = torch.tensor(m)" ] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1.],\n", " [1.],\n", " [1.],\n", " [1.]])" ] }, "execution_count": 95, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m @ torch.ones(4, 1)" ] }, { "cell_type": "code", "execution_count": 192, "metadata": {}, "outputs": [], "source": [ "r = torch.tensor([[2.0, 1.0, -1.0, -2.0]]).T\n", "v = torch.tensor([[200.0, 100.0, -100.0, -200.0]]).T\n", "gamma = 0.9" ] }, { "cell_type": "code", "execution_count": 198, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 5.4054],\n", " [ 2.7027],\n", " [-2.7027],\n", " [-5.4054]])\n", "tensor([[ 5.4054],\n", " [ 2.7027],\n", " [-2.7027],\n", " [-5.4054]])\n", "tensor([[ 5.4054],\n", " [ 2.7027],\n", " [-2.7027],\n", " [-5.4054]])\n", "tensor([[ 5.4054],\n", " [ 2.7027],\n", " [-2.7027],\n", " [-5.4054]])\n", "tensor([[ 5.4054],\n", " [ 2.7027],\n", " [-2.7027],\n", " [-5.4054]])\n", "tensor([[ 5.4054],\n", " [ 2.7027],\n", " [-2.7027],\n", " [-5.4054]])\n", "tensor([[ 5.4054],\n", " [ 2.7027],\n", " [-2.7027],\n", " [-5.4054]])\n", "tensor([[ 5.4054],\n", " [ 2.7027],\n", " [-2.7027],\n", " [-5.4054]])\n", "tensor([[ 5.4054],\n", " [ 2.7027],\n", " [-2.7027],\n", " [-5.4054]])\n", "tensor([[ 5.4054],\n", " [ 2.7027],\n", " [-2.7027],\n", " [-5.4054]])\n" ] } ], "source": [ "for i in range(10):\n", " print(v)\n", " v = gamma * m @ v + r" ] }, { "cell_type": "code", "execution_count": 142, "metadata": {}, "outputs": [], "source": [ "m = [\n", " [0.5, 0.5, 0.0],\n", " [0.25, 0.5, 0.25],\n", " [0, 0.5, 0.5]\n", "]" ] }, { "cell_type": "code", "execution_count": 143, "metadata": {}, "outputs": [], "source": [ "m = torch.tensor(m)" ] }, { "cell_type": "code", "execution_count": 160, "metadata": {}, "outputs": [], "source": [ "n = torch.tensor([[0.0, 0.9, 0.1]])" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'n' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m@\u001b[0m 
\u001b[0mm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mNameError\u001b[0m: name 'n' is not defined" ] } ], "source": [ "for i in range(10):\n", " print(n)\n", " n = n @ m\n", " \n", " \n" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch.autograd import Variable" ] }, { "cell_type": "code", "execution_count": 161, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 162, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 170, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[1., 0., 0.]], grad_fn=),\n", " tensor([[-0.2796, 0.1884, 0.0913]]))" ] }, "execution_count": 170, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logits = torch.tensor([[1., 1., 1.]], requires_grad=True)\n", "y = F.gumbel_softmax(logits, tau=1, hard=True)\n", "(y * torch.arange(3)).sum().backward()\n", "y, logits.grad" ] }, { "cell_type": "code", "execution_count": 148, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1., 1.]], grad_fn=)" ] }, "execution_count": 148, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_soft = F.gumbel_softmax(logits, tau=1, hard=False)\n", "y_soft.max(-1, keepdim=True)[1] - y_soft.detach() + y_soft" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 117, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1., 0.]], grad_fn=)" ] }, "execution_count": 117, "metadata": {}, "output_type": "execute_result" } ], "source": [ "F.gumbel_softmax(logits, tau=1, hard=True)" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "def sample_gumbel(shape, eps=1e-20):\n", " U = torch.rand(shape)\n", " return -Variable(torch.log(-torch.log(U + eps) + eps))\n", "\n", "def gumbel_softmax_sample(logits, temperature):\n", " y = logits + sample_gumbel(logits.size())\n", " return F.softmax(y / temperature, dim=-1)\n", "\n", "def gumbel_softmax(logits, temperature):\n", " \"\"\"\n", " ST-gumple-softmax\n", " input: [*, n_class]\n", " return: flatten --> [*, n_class] an one-hot vector\n", " \"\"\"\n", " y = gumbel_softmax_sample(logits, temperature)\n", " shape = y.size()\n", " _, ind = y.max(dim=-1)\n", " y_hard = torch.zeros_like(y).view(-1, shape[-1])\n", " y_hard.scatter_(1, ind.view(-1, 1), 1)\n", " y_hard = y_hard.view(*shape)\n", " y_hard = (y_hard - y).detach() + y\n", " return y_hard.view(-1,latent_dim*categorical_dim), y, ind" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'latent_dim' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in 
\u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgumbel_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m\u001b[0m in \u001b[0;36mgumbel_softmax\u001b[0;34m(logits, temperature)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0my_hard\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_hard\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0my_hard\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0my_hard\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0my_hard\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlatent_dim\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mcategorical_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mNameError\u001b[0m: name 'latent_dim' is not defined" ] } ], "source": [ "gumbel_softmax(torch.randn(4, 3), 0.1)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "latent_dim = 10\n", "categorical_dim = 3\n", "x = torch.randn(5, latent_dim, categorical_dim)" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [], "source": [ "x.requires_grad=True" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [], "source": [ "y, k, ind = gumbel_softmax(x, 0.1)" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1, 2, 2, 2, 2, 0, 2, 1, 2, 1],\n", " [0, 2, 0, 0, 2, 1, 0, 0, 0, 0],\n", " [0, 0, 2, 1, 1, 1, 1, 0, 0, 0],\n", " [1, 2, 1, 2, 0, 1, 1, 1, 0, 0],\n", " [2, 0, 0, 1, 0, 2, 0, 1, 0, 1]])" ] }, "execution_count": 87, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ind" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.return_types.max(\n", "values=tensor([[0.5745, 1.0000, 0.9985, 1.0000, 1.0000, 0.5934, 1.0000, 0.8247, 1.0000,\n", " 0.9999],\n", " [0.7558, 1.0000, 0.9999, 1.0000, 1.0000, 0.5294, 0.9994, 0.8727, 0.9999,\n", " 0.9213],\n", " [1.0000, 1.0000, 1.0000, 1.0000, 0.9942, 1.0000, 0.9996, 1.0000, 1.0000,\n", " 1.0000],\n", " [0.9993, 1.0000, 1.0000, 1.0000, 0.9998, 0.9998, 1.0000, 0.9661, 0.5564,\n", " 1.0000],\n", " [1.0000, 1.0000, 0.9999, 1.0000, 0.9785, 0.8706, 1.0000, 1.0000, 1.0000,\n", " 0.5304]], grad_fn=),\n", "indices=tensor([[1, 2, 2, 2, 2, 0, 2, 1, 2, 1],\n", " [0, 2, 0, 0, 2, 1, 0, 0, 0, 0],\n", " [0, 0, 2, 1, 1, 1, 1, 0, 0, 0],\n", " [1, 2, 1, 2, 0, 1, 1, 1, 0, 0],\n", " [2, 0, 0, 1, 0, 2, 0, 1, 0, 1]]))" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], "source": [ "k.max(dim=-1)" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, 
"outputs": [ { "data": { "text/plain": [ "tensor([1, 2, 2, 2, 2, 0, 2, 1, 2, 1])" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "k.max(dim=-1)[1][0]" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([10])" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.arange(10).shape" ] }, { "cell_type": "code", "execution_count": 171, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1., 0., 0.]], grad_fn=)" ] }, "execution_count": 171, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "code", "execution_count": 107, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'nn' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mNameError\u001b[0m: name 'nn' is not defined" ] } ], "source": [ "l = nn.embedding(3, 4)" ] }, { "cell_type": "code", "execution_count": 172, "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'nn' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mclass\u001b[0m \u001b[0mRewardModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mNameError\u001b[0m: name 'nn' is not defined" ] } ], "source": [ "class RewardModel(nn.Module):\n", "\n", " def __init__(self, model):\n", " super().__init__()\n", " self.embedding = model\n", " self.score = nn.Linear(model.embed_dim, 1, bias=False)\n", "\n", " def forward(self, x, seq_len=None):\n", " # x:表示文本,形状(B, T, C), seq_len:表示文本长度,形状(B)\n", " B, T = x.shape\n", " emb = self.embedding(x).last_hidden_state # (B, T, C)\n", " ind = torch.arange(B, device=x.device)\n", " if seq_len == None:\n", " seq_len = torch.tensor([T] * B)\n", " # 获取最后一个词元的特征\n", " pooled_emb = 
emb[ind, seq_len - 1] # (B, C)\n", " score = self.score(pooled_emb) # (B, 1)\n", " return score\n", "\n", "\n", "r_model = RewardModel(GPT2Model.from_pretrained('gpt2'))" ] }, { "cell_type": "code", "execution_count": 175, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(1, 4, requires_grad=True)" ] }, { "cell_type": "code", "execution_count": 176, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1.7199, grad_fn=)" ] }, "execution_count": 176, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.max(x)" ] }, { "cell_type": "code", "execution_count": 181, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1]])" ] }, "execution_count": 181, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logits = torch.randn(1, 5, requires_grad=True)\n", "probs = F.softmax(logits, dim=-1)\n", "y = torch.multinomial(probs, num_samples=1)\n", "y" ] }, { "cell_type": "code", "execution_count": 182, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1., 0., 0., 0., 0.]], grad_fn=)" ] }, "execution_count": 182, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_one_hot = F.gumbel_softmax(logits, hard=True)\n", "gumbel_y = torch.argmax(y_one_hot, dim=-1, keepdim=True)\n", "y_one_hot" ] }, { "cell_type": "code", "execution_count": 183, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0]])" ] }, "execution_count": 183, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gumbel_y" ] }, { "cell_type": "code", "execution_count": 189, "metadata": {}, "outputs": [], "source": [ "def log(t, eps = 1e-20):\n", " return torch.log(t.clamp(min = eps))\n", "\n", "def gumbel_noise(t):\n", " noise = torch.zeros_like(t).uniform_(0, 1)\n", " return -log(-log(noise))\n", "\n", "def gumbel_sample(t, temperature = 1., dim = -1):\n", " return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)" ] }, { "cell_type": "code", "execution_count": 190, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([2])" ] }, "execution_count": 190, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gumbel_sample(torch.randn(1, 4, requires_grad=True))" ] }, { "cell_type": "code", "execution_count": 191, "metadata": {}, "outputs": [], "source": [ "B, T, vs = (3, 4, 20)\n", "logits = torch.randn(B, T, vs)\n", "labels = torch.randint(vs, (B, T))" ] }, { "cell_type": "code", "execution_count": 197, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(3.9027), tensor(3.9027))" ] }, "execution_count": 197, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l = F.cross_entropy(logits.view(B * T, vs), labels.view(B * T), reduction='none')\n", "l.mean(), F.cross_entropy(logits.view(B * T, vs), labels.view(B * T))" ] }, { "cell_type": "code", "execution_count": 203, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-4.3894, -4.6075, -2.8143, -3.5261],\n", " [-4.3982, -3.5565, -2.3154, -4.5414],\n", " [-4.1661, -3.0418, -2.9297, -6.5461]])" ] }, "execution_count": 203, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# per-token log-probabilities of the labels: lnP[b, t] = log p(labels[b, t] | logits[b, t])\n", "lnP = -F.cross_entropy(logits.transpose(-2, -1), labels, reduction='none')\n", "lnP" ] }, { "cell_type": "code", "execution_count": 221, "metadata": {}, "outputs": [], "source": [ "lossi = F.cross_entropy(logits.transpose(-2, -1), labels, reduction='none')" ] },
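{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Added sketch: how the per-token log-probs above feed a REINFORCE-style surrogate loss.\n", "# `reward` is a made-up per-sequence reward, purely for illustration.\n", "seq_logprob = lnP.sum(-1) # log p(sequence) = sum of per-token log-probs, shape (B,)\n", "reward = torch.tensor([1.0, -0.5, 2.0]) # hypothetical rewards for the B = 3 sequences\n", "surrogate = -(reward * seq_logprob).mean() # minimizing this raises log-probs of high-reward sequences\n", "surrogate" ] }, { "cell_type": "code", "execution_count": 225, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(3.9027)" ] }, "execution_count": 225, "metadata": {}, 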
"output_type": "execute_result" } ], "source": [ "lossi.mean(-1).mean(-1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokenizer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 220, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(3.9027)" ] }, "execution_count": 220, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(-r * lnP).mean()" ] }, { "cell_type": "code", "execution_count": 211, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([-15.3374, -14.8116, -16.6836])" ] }, "execution_count": 211, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lnP.sum(-1)" ] }, { "cell_type": "code", "execution_count": 213, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0, 0, 0, 0, 1, 1, 1, 0, 1, 1],\n", " [0, 0, 1, 0, 1, 1, 0, 0, 1, 0],\n", " [0, 0, 1, 1, 1, 1, 1, 1, 1, 0],\n", " [0, 1, 1, 1, 1, 1, 0, 0, 0, 1],\n", " [1, 0, 1, 1, 1, 1, 1, 1, 0, 0],\n", " [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [1, 0, 0, 0, 0, 1, 1, 1, 0, 0],\n", " [0, 1, 1, 0, 0, 0, 0, 1, 1, 0],\n", " [0, 1, 1, 1, 1, 1, 0, 0, 1, 0],\n", " [1, 1, 1, 1, 0, 0, 0, 0, 1, 1],\n", " [0, 1, 0, 1, 1, 1, 1, 1, 0, 1],\n", " [0, 1, 0, 0, 1, 1, 0, 1, 0, 1],\n", " [0, 0, 1, 1, 0, 1, 0, 0, 0, 0],\n", " [0, 0, 1, 1, 1, 0, 0, 0, 1, 0],\n", " [0, 1, 1, 0, 0, 1, 0, 1, 1, 0],\n", " [0, 0, 0, 0, 0, 1, 1, 1, 1, 0],\n", " [1, 0, 0, 1, 0, 0, 1, 0, 1, 1],\n", " [1, 1, 1, 0, 0, 0, 0, 0, 0, 1],\n", " [1, 1, 1, 0, 1, 1, 1, 1, 0, 0],\n", " [0, 0, 0, 1, 1, 1, 1, 1, 0, 1]])" ] }, "execution_count": 213, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.randint(2, (20, 10))" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "a = torch.tensor([0.5, 0.2, 0.1])\n", "r = torch.tensor([1001., 1002., 1003.])" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([500.5000, 200.4000, 100.3000]), tensor(208.2627))" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a * r, (a * r).std()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([0.0000, 0.2000, 0.2000]), tensor(0.1155))" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a * (r-1001), (a * (r-1001)).std()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0., 1., 2.])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "r-1001" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(208.1666)" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(1000 * a).std()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.2082)" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(1 * a).std()" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"tensor([1001.0000, 400.4000, 200.2000])" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "1000 * a + 1002 * a" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1.5000, 0.6000, 0.3000])" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a + 2 * a" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 20])" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "g = torch.normal(1, 1, (2, 20))\n", "g.shape" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 1])" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "r = torch.normal(1, 1, (2, 1))\n", "r.shape" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.9881, 0.7762, 1.3812, 0.7958, 1.0094, 0.1070, 0.1709, 0.3961,\n", " -0.0118, 0.2604, 0.5242, 0.2400, 0.1989, 0.2737, 1.4879, 0.4091,\n", " 0.7199, -0.2781, 0.7418, 1.2079],\n", " [ 3.4830, 0.1242, 0.8680, 0.0464, 0.8523, 1.0290, -0.3891, 2.5647,\n", " -0.1333, 3.3386, 0.1980, 0.2184, 0.4268, 4.9361, 3.0060, 1.3053,\n", " 2.5407, 1.2318, 1.8263, 1.0207]])" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "help(g * r)" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Help on method norm in module torch._tensor:\n", "\n", "norm(p: Union[float, str, NoneType] = 'fro', dim=None, keepdim=False, dtype=None) method of torch.Tensor instance\n", " See :func:`torch.norm`\n", "\n" ] } ], "source": [ "help(g.norm)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1.4053)" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(g * rr).std()" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([39.9750, 37.7648])" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(g**2).sum(dim=-1)" ] }, { "cell_type": "code", "execution_count": 179, "metadata": {}, "outputs": [], "source": [ "def norm_(g):\n", " k = g / g.norm(dim=-1, keepdim=True)\n", " return k" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(99.9496)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num = 10000\n", "grad = torch.normal(0, 1, (num, 100))\n", "g = torch.normal(100, 1, (num, 1))\n", "(grad * g).std()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(1.0004)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num = 10000\n", "grad = torch.normal(0, 1, (num, 100))\n", "g = torch.normal(100, 1, (num, 1)) - 100\n", "(grad * g).std()" ] }, { "cell_type": "code", "execution_count": 261, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.0996)" ] }, "execution_count": 261, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "(g * r).mean(-1).std()" ] }, { "cell_type": "code", "execution_count": 241, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[100.],\n", " [100.],\n", " [100.],\n", " [100.],\n", " [100.],\n", " [100.],\n", " [100.],\n", " [100.],\n", " [100.],\n", " [100.]])" ] }, "execution_count": 241, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.normal(100, 0, (10, 1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }