|
|
@@ -1,2994 +0,0 @@
|
|
|
-{
|
|
|
- "cells": [
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 115,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n",
|
|
|
- "\u001b[0mCollecting trl\n",
|
|
|
- " Using cached trl-0.7.4-py3-none-any.whl.metadata (10 kB)\n",
|
|
|
- "Requirement already satisfied: torch>=1.4.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (2.0.1)\n",
|
|
|
- "Requirement already satisfied: transformers>=4.18.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (4.31.0)\n",
|
|
|
- "Requirement already satisfied: numpy>=1.18.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (1.24.4)\n",
|
|
|
- "Requirement already satisfied: accelerate in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (0.21.0)\n",
|
|
|
- "Requirement already satisfied: datasets in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (2.14.2)\n",
|
|
|
- "Requirement already satisfied: tyro>=0.5.11 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from trl) (0.5.12)\n",
|
|
|
- "Requirement already satisfied: filelock in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (3.0.12)\n",
|
|
|
- "Requirement already satisfied: typing-extensions in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (4.8.0)\n",
|
|
|
- "Requirement already satisfied: sympy in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (1.6.2)\n",
|
|
|
- "Requirement already satisfied: networkx in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (2.5)\n",
|
|
|
- "Requirement already satisfied: jinja2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from torch>=1.4.0->trl) (2.11.2)\n",
|
|
|
- "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.16.4)\n",
|
|
|
- "Requirement already satisfied: packaging>=20.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (23.1)\n",
|
|
|
- "Requirement already satisfied: pyyaml>=5.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (5.3.1)\n",
|
|
|
- "Requirement already satisfied: regex!=2019.12.17 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (2020.10.15)\n",
|
|
|
- "Requirement already satisfied: requests in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (2.24.0)\n",
|
|
|
- "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.13.3)\n",
|
|
|
- "Requirement already satisfied: safetensors>=0.3.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (0.3.1)\n",
|
|
|
- "Requirement already satisfied: tqdm>=4.27 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from transformers>=4.18.0->trl) (4.65.0)\n",
|
|
|
- "Requirement already satisfied: docstring-parser>=0.14.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (0.15)\n",
|
|
|
- "Requirement already satisfied: rich>=11.1.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (13.6.0)\n",
|
|
|
- "Requirement already satisfied: shtab>=1.5.6 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from tyro>=0.5.11->trl) (1.6.4)\n",
|
|
|
- "Requirement already satisfied: psutil in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from accelerate->trl) (5.7.2)\n",
|
|
|
- "Requirement already satisfied: pyarrow>=8.0.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (12.0.1)\n",
|
|
|
- "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (0.3.7)\n",
|
|
|
- "Requirement already satisfied: pandas in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (2.0.3)\n",
|
|
|
- "Requirement already satisfied: xxhash in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (3.3.0)\n",
|
|
|
- "Requirement already satisfied: multiprocess in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (0.70.15)\n",
|
|
|
- "Requirement already satisfied: fsspec>=2021.11.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from fsspec[http]>=2021.11.1->datasets->trl) (2023.6.0)\n",
|
|
|
- "Requirement already satisfied: aiohttp in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from datasets->trl) (3.8.5)\n",
|
|
|
- "Requirement already satisfied: attrs>=17.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (23.1.0)\n",
|
|
|
- "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (3.2.0)\n",
|
|
|
- "Requirement already satisfied: multidict<7.0,>=4.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (6.0.4)\n",
|
|
|
- "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (4.0.2)\n",
|
|
|
- "Requirement already satisfied: yarl<2.0,>=1.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.9.2)\n",
|
|
|
- "Requirement already satisfied: frozenlist>=1.1.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.4.0)\n",
|
|
|
- "Requirement already satisfied: aiosignal>=1.1.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from aiohttp->datasets->trl) (1.3.1)\n",
|
|
|
- "Requirement already satisfied: chardet<4,>=3.0.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (3.0.4)\n",
|
|
|
- "Requirement already satisfied: idna<3,>=2.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (2.10)\n",
|
|
|
- "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (1.25.11)\n",
|
|
|
- "Requirement already satisfied: certifi>=2017.4.17 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers>=4.18.0->trl) (2020.6.20)\n",
|
|
|
- "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n",
|
|
|
- "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n",
|
|
|
- "Requirement already satisfied: MarkupSafe>=0.23 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from jinja2->torch>=1.4.0->trl) (1.1.1)\n",
|
|
|
- "Requirement already satisfied: decorator>=4.3.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from networkx->torch>=1.4.0->trl) (4.4.2)\n",
|
|
|
- "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2.8.2)\n",
|
|
|
- "Requirement already satisfied: pytz>=2020.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2020.1)\n",
|
|
|
- "Requirement already satisfied: tzdata>=2022.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pandas->datasets->trl) (2023.3)\n",
|
|
|
- "Requirement already satisfied: mpmath>=0.19 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from sympy->torch>=1.4.0->trl) (1.1.0)\n",
|
|
|
- "Requirement already satisfied: mdurl~=0.1 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n",
|
|
|
- "Requirement already satisfied: six>=1.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from python-dateutil>=2.8.2->pandas->datasets->trl) (1.15.0)\n",
|
|
|
- "Using cached trl-0.7.4-py3-none-any.whl (133 kB)\n",
|
|
|
- "\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n",
|
|
|
- "\u001b[0m\u001b[33mDEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
|
|
|
- "\u001b[0mInstalling collected packages: trl\n"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "Successfully installed trl-0.7.4\r\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "!pip install --user trl"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 5,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "import torch\n",
|
|
|
- "import torch.nn.functional as F"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 7,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "a = torch.randn(1, 4, requires_grad=True)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 8,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[ 0.2511, -0.2847, -1.3987, 0.9078]], requires_grad=True)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 8,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "a"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 9,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "s = torch.argmax(a)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 10,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(3)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 10,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "s"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 19,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([ 0., 10., 3., 0.], requires_grad=True)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 19,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "weights = torch.tensor([0, 10, 3, 0], dtype=torch.float, requires_grad=True) # create a tensor of weights\n",
|
|
|
- "weights"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 20,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "s = torch.multinomial(weights, 1)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 21,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([2])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 21,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "s"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 22,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "RuntimeError",
|
|
|
- "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-22-12626f95b38d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
|
|
|
- "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "s.backward()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 226,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "import torch\n",
|
|
|
- "from transformers import AutoTokenizer, GPT2LMHeadModel\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "torch.manual_seed(12046)\n",
|
|
|
- "\n",
|
|
|
- "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 227,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "50257"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 227,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "tokenizer.vocab_size"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 231,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "1.027730370438185e+47"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 231,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "import math\n",
|
|
|
- "\n",
|
|
|
- "math.pow(50256, 10)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 55,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stderr",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n",
|
|
|
- "Input length of input_ids is 7, but `max_length` is set to 1. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "res = model.generate(**ids, max_length=1)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 56,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[2061, 318, 262, 3139, 286, 2807, 30, 198]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 56,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "res"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 53,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "l = model(**ids).logits"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 54,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[[ -37.2172, -36.8864, -40.3563, ..., -43.4351, -43.0232,\n",
|
|
|
- " -37.0993],\n",
|
|
|
- " [ -84.7418, -82.6453, -88.2857, ..., -89.4501, -90.0555,\n",
|
|
|
- " -84.6770],\n",
|
|
|
- " [ -94.2372, -92.2747, -95.0347, ..., -94.9362, -97.9820,\n",
|
|
|
- " -91.8550],\n",
|
|
|
- " ...,\n",
|
|
|
- " [ -75.3586, -73.5543, -77.7737, ..., -80.5165, -81.9881,\n",
|
|
|
- " -76.4165],\n",
|
|
|
- " [ -77.2812, -75.8471, -80.8869, ..., -88.2672, -86.0258,\n",
|
|
|
- " -79.5727],\n",
|
|
|
- " [-123.9086, -123.2837, -124.5704, ..., -135.1090, -134.9653,\n",
|
|
|
- " -119.7716]]], grad_fn=<UnsafeViewBackward0>)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 54,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "l"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 60,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[[ -37.2172, -36.8864, -40.3563, ..., -43.4351, -43.0232,\n",
|
|
|
- " -37.0993],\n",
|
|
|
- " [ -84.7418, -82.6453, -88.2857, ..., -89.4501, -90.0555,\n",
|
|
|
- " -84.6770],\n",
|
|
|
- " [ -94.2372, -92.2747, -95.0347, ..., -94.9362, -97.9820,\n",
|
|
|
- " -91.8550],\n",
|
|
|
- " ...,\n",
|
|
|
- " [ -75.3586, -73.5543, -77.7737, ..., -80.5165, -81.9881,\n",
|
|
|
- " -76.4165],\n",
|
|
|
- " [ -77.2812, -75.8471, -80.8869, ..., -88.2672, -86.0258,\n",
|
|
|
- " -79.5727],\n",
|
|
|
- " [-123.9086, -123.2837, -124.5704, ..., -135.1090, -134.9653,\n",
|
|
|
- " -119.7716]]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 60,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "with torch.no_grad():\n",
|
|
|
- " ll = model(**ids).logits\n",
|
|
|
- "ll"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 87,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "weights = torch.tensor([0, 10, 3, 0], dtype=torch.float, requires_grad=True)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 88,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([1])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 88,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "y = torch.multinomial(weights, 1)\n",
|
|
|
- "y"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 89,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "RuntimeError",
|
|
|
- "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-89-ab75bb780f4c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
|
|
|
- "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "y.backward()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 91,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stderr",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "<ipython-input-91-26e670d9452d>:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
|
|
- " F.softmax(weights)[y]\n"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([0.9990], grad_fn=<IndexBackward0>)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 91,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "F.softmax(weights)[y]"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 82,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([2])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 82,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "y"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 86,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "k = (weights == weights[y]).nonzero()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 51,
|
|
|
- "metadata": {
|
|
|
- "scrolled": true
|
|
|
- },
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "RuntimeError",
|
|
|
- "evalue": "only Tensors of floating point and complex dtype can require gradients",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-51-e107420eacad>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
|
- "\u001b[0;31mRuntimeError\u001b[0m: only Tensors of floating point and complex dtype can require gradients"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "ids['input_ids'].requires_grad=True"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 49,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "torch.Tensor"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 49,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 50,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "torch.Tensor"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 50,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "type(a)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 57,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(1.0545, grad_fn=<SumBackward0>)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 57,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "a = torch.rand(1, 4)\n",
|
|
|
- "a.requires_grad=True\n",
|
|
|
- "y = torch.sum(a)\n",
|
|
|
- "y"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 46,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "torch.sum(a).backward()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 37,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(8903)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 37,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "y = torch.sum(ids['input_ids'])\n",
|
|
|
- "y"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 38,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "RuntimeError",
|
|
|
- "evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-38-ab75bb780f4c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m )\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass\n",
|
|
|
- "\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "y.backward()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 94,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "x = torch.randn(3, 2, requires_grad=True)\n",
|
|
|
- "y = torch.ones(3, 2)\n",
|
|
|
- "x\n",
|
|
|
- "z = torch.where(x > 0, 1.0, 0.0)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 95,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[0., 1.],\n",
|
|
|
- " [1., 1.],\n",
|
|
|
- " [0., 0.]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 95,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "z"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 103,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "t = (x - 0 > 0.001) * 1 + 0"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 100,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[-1.4747, 0.2449],\n",
|
|
|
- " [ 0.8045, 0.4111],\n",
|
|
|
- " [-0.7769, -0.4431]], requires_grad=True)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 100,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "x"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 104,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[0, 1],\n",
|
|
|
- " [1, 1],\n",
|
|
|
- " [0, 0]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 104,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "t"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 106,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[False, True],\n",
|
|
|
- " [ True, True],\n",
|
|
|
- " [False, False]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 106,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "(x - 0 > 0.001)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 108,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[-1.4747, 0.2449],\n",
|
|
|
- " [ 0.8045, 0.4111],\n",
|
|
|
- " [-0.7769, -0.4431]], requires_grad=True)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 108,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "x"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 109,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(0.8045, grad_fn=<MaxBackward1>)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 109,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "torch.max(x)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 117,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "ModuleNotFoundError",
|
|
|
- "evalue": "No module named 'trl'",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-117-8b0f0cf4af74>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtrl\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAutoModelForCausalLMWithValueHea\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
|
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'trl'"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "from trl import AutoModelForCausalLMWithValueHea"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 1,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stderr",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "/Users/tgbaggio/.local/lib/python3.8/site-packages/trl/trainer/ppo_config.py:141: UserWarning: The `optimize_cuda_cache` arguement will be deprecated soon, please use `optimize_device_cache` instead.\n",
|
|
|
- " warnings.warn(\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "import trl"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 2,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "from trl import AutoModelForCausalLMWithValueHead"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 4,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "application/vnd.jupyter.widget-view+json": {
|
|
|
- "model_id": "c8eb15f8baa64dbfa6a41e4349bc1068",
|
|
|
- "version_major": 2,
|
|
|
- "version_minor": 0
|
|
|
- },
|
|
|
- "text/plain": [
|
|
|
- "Downloading (…)lve/main/config.json: 0%| | 0.00/577 [00:00<?, ?B/s]"
|
|
|
- ]
|
|
|
- },
|
|
|
- "metadata": {},
|
|
|
- "output_type": "display_data"
|
|
|
- },
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "application/vnd.jupyter.widget-view+json": {
|
|
|
- "model_id": "c92409e5bc944693b5578571245093bf",
|
|
|
- "version_major": 2,
|
|
|
- "version_minor": 0
|
|
|
- },
|
|
|
- "text/plain": [
|
|
|
- "Downloading pytorch_model.bin: 0%| | 0.00/548M [00:00<?, ?B/s]"
|
|
|
- ]
|
|
|
- },
|
|
|
- "metadata": {},
|
|
|
- "output_type": "display_data"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "model = AutoModelForCausalLMWithValueHead.from_pretrained('lvwerra/gpt2-imdb')"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 5,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "AutoModelForCausalLMWithValueHead(\n",
|
|
|
- " (pretrained_model): GPT2LMHeadModel(\n",
|
|
|
- " (transformer): GPT2Model(\n",
|
|
|
- " (wte): Embedding(50257, 768)\n",
|
|
|
- " (wpe): Embedding(1024, 768)\n",
|
|
|
- " (drop): Dropout(p=0.1, inplace=False)\n",
|
|
|
- " (h): ModuleList(\n",
|
|
|
- " (0-11): 12 x GPT2Block(\n",
|
|
|
- " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
|
|
- " (attn): GPT2Attention(\n",
|
|
|
- " (c_attn): Conv1D()\n",
|
|
|
- " (c_proj): Conv1D()\n",
|
|
|
- " (attn_dropout): Dropout(p=0.1, inplace=False)\n",
|
|
|
- " (resid_dropout): Dropout(p=0.1, inplace=False)\n",
|
|
|
- " )\n",
|
|
|
- " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
|
|
- " (mlp): GPT2MLP(\n",
|
|
|
- " (c_fc): Conv1D()\n",
|
|
|
- " (c_proj): Conv1D()\n",
|
|
|
- " (act): NewGELUActivation()\n",
|
|
|
- " (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
|
- " )\n",
|
|
|
- " )\n",
|
|
|
- " )\n",
|
|
|
- " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
|
|
- " )\n",
|
|
|
- " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
|
|
|
- " )\n",
|
|
|
- " (v_head): ValueHead(\n",
|
|
|
- " (dropout): Dropout(p=0.1, inplace=False)\n",
|
|
|
- " (summary): Linear(in_features=768, out_features=1, bias=True)\n",
|
|
|
- " (flatten): Flatten(start_dim=1, end_dim=-1)\n",
|
|
|
- " )\n",
|
|
|
- ")"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 5,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "model"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 10,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "l = model(**ids)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 13,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "torch.Size([1, 7, 50257])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 13,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "l[0].shape"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 15,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "torch.Size([1, 7])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 15,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "l[2].shape"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 16,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "from transformers import AutoTokenizer, pipeline"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 17,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "application/vnd.jupyter.widget-view+json": {
|
|
|
- "model_id": "a14f54b9e20d4fa2bf97d5d707531fcf",
|
|
|
- "version_major": 2,
|
|
|
- "version_minor": 0
|
|
|
- },
|
|
|
- "text/plain": [
|
|
|
- "Downloading (…)lve/main/config.json: 0%| | 0.00/735 [00:00<?, ?B/s]"
|
|
|
- ]
|
|
|
- },
|
|
|
- "metadata": {},
|
|
|
- "output_type": "display_data"
|
|
|
- },
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "application/vnd.jupyter.widget-view+json": {
|
|
|
- "model_id": "1397323a065743a19dc8e235b0a49018",
|
|
|
- "version_major": 2,
|
|
|
- "version_minor": 0
|
|
|
- },
|
|
|
- "text/plain": [
|
|
|
- "Downloading pytorch_model.bin: 0%| | 0.00/268M [00:00<?, ?B/s]"
|
|
|
- ]
|
|
|
- },
|
|
|
- "metadata": {},
|
|
|
- "output_type": "display_data"
|
|
|
- },
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "application/vnd.jupyter.widget-view+json": {
|
|
|
- "model_id": "898a7c9e19d549459bd8a92006ae3f2b",
|
|
|
- "version_major": 2,
|
|
|
- "version_minor": 0
|
|
|
- },
|
|
|
- "text/plain": [
|
|
|
- "Downloading (…)okenizer_config.json: 0%| | 0.00/333 [00:00<?, ?B/s]"
|
|
|
- ]
|
|
|
- },
|
|
|
- "metadata": {},
|
|
|
- "output_type": "display_data"
|
|
|
- },
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "application/vnd.jupyter.widget-view+json": {
|
|
|
- "model_id": "123feae78d4f4bc182d7405985f33462",
|
|
|
- "version_major": 2,
|
|
|
- "version_minor": 0
|
|
|
- },
|
|
|
- "text/plain": [
|
|
|
- "Downloading (…)solve/main/vocab.txt: 0%| | 0.00/232k [00:00<?, ?B/s]"
|
|
|
- ]
|
|
|
- },
|
|
|
- "metadata": {},
|
|
|
- "output_type": "display_data"
|
|
|
- },
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "application/vnd.jupyter.widget-view+json": {
|
|
|
- "model_id": "fbf7ebeb54244310bfe2d225e1441db4",
|
|
|
- "version_major": 2,
|
|
|
- "version_minor": 0
|
|
|
- },
|
|
|
- "text/plain": [
|
|
|
- "Downloading (…)/main/tokenizer.json: 0%| | 0.00/466k [00:00<?, ?B/s]"
|
|
|
- ]
|
|
|
- },
|
|
|
- "metadata": {},
|
|
|
- "output_type": "display_data"
|
|
|
- },
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "application/vnd.jupyter.widget-view+json": {
|
|
|
- "model_id": "7170262df42543be877e56798b06a64a",
|
|
|
- "version_major": 2,
|
|
|
- "version_minor": 0
|
|
|
- },
|
|
|
- "text/plain": [
|
|
|
- "Downloading (…)cial_tokens_map.json: 0%| | 0.00/112 [00:00<?, ?B/s]"
|
|
|
- ]
|
|
|
- },
|
|
|
- "metadata": {},
|
|
|
- "output_type": "display_data"
|
|
|
- },
|
|
|
- {
|
|
|
- "name": "stderr",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\n",
|
|
|
- "pip install xformers.\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "p = pipeline('sentiment-analysis', 'lvwerra/distilbert-imdb')"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 21,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "sent_kwargs = {\"return_all_scores\": True, \"function_to_apply\": \"none\", \"batch_size\": 16}"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 29,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "{'input_ids': tensor([[2061, 318, 262, 3139, 286, 2807, 30]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 29,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "question = 'What is the capital of China?'\n",
|
|
|
- "ids = tokenizer(question, return_tensors=\"pt\")\n",
|
|
|
- "ids"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 30,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "pipe_outputs = p(question, **sent_kwargs)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 31,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "[[{'label': 'NEGATIVE', 'score': 0.28462526202201843},\n",
|
|
|
- " {'label': 'POSITIVE', 'score': -0.498414009809494}]]"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 31,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "pipe_outputs"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 32,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "score = torch.randn(3)\n",
|
|
|
- "log = torch.randn(3, 7)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 45,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "tensor(2.3532) torch.Size([7])\n",
|
|
|
- "tensor(-0.3089) torch.Size([7])\n",
|
|
|
- "tensor(-1.0061) torch.Size([7])\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "for s, l in zip(score, log):\n",
|
|
|
- " print(s, l.shape)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 46,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "x = torch.randn(4)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 47,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([ 0.8624, 0.1586, -1.3384, 0.4741])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 47,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "x"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 48,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "x += 1"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 49,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([ 1.8624, 1.1586, -0.3384, 1.4741])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 49,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "x"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 51,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "x[2] += 1"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 55,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[-0.0498, 0.3148, -0.6248, 0.0660],\n",
|
|
|
- " [ 1.1840, -1.0514, -0.8765, -0.0313],\n",
|
|
|
- " [ 2.3417, -0.0704, -0.0996, 1.7966]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 55,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "x = torch.randn(3, 4)\n",
|
|
|
- "x"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 56,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "ValueError",
|
|
|
- "evalue": "step must be greater than zero",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-56-fc433e12835a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
|
- "\u001b[0;31mValueError\u001b[0m: step must be greater than zero"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "x[::-1]"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 79,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "self_config_gamma = 0.8\n",
|
|
|
- "self_config_lam = 0.9\n",
|
|
|
- "\n",
|
|
|
- "def compute_advantages(\n",
|
|
|
- " values: torch.FloatTensor,\n",
|
|
|
- " rewards: torch.FloatTensor):\n",
|
|
|
- " lastgaelam = 0\n",
|
|
|
- " advantages_reversed = []\n",
|
|
|
- " gen_len = rewards.shape[-1]\n",
|
|
|
- " for t in reversed(range(gen_len)):\n",
|
|
|
- " nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0\n",
|
|
|
- " delta = rewards[:, t] + self_config_gamma * nextvalues - values[:, t]\n",
|
|
|
- " lastgaelam = delta + self_config_gamma * self_config_lam * lastgaelam\n",
|
|
|
- " advantages_reversed.append(lastgaelam)\n",
|
|
|
- " advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)\n",
|
|
|
- "\n",
|
|
|
- " returns = advantages + values\n",
|
|
|
- " advantages = advantages.detach()\n",
|
|
|
- " return values, advantages, returns\n"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 80,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "values = torch.tensor([[3., 4., 2.]])\n",
|
|
|
- "rewards = torch.tensor([[-1., 2., 2.]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 81,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "(tensor([[3., 4., 2.]]),\n",
|
|
|
- " tensor([[-1.0880, -0.4000, 0.0000]]),\n",
|
|
|
- " tensor([[1.9120, 3.6000, 2.0000]]))"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 81,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "compute_advantages(values, rewards)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 1,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "1.8800000000000003"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 1,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "0.8 * 2 + 0.8 ** 2 * 2 - 1"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 5,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "import gym\n",
|
|
|
- "import pygame\n",
|
|
|
- "\n",
|
|
|
- "env = gym.make('MountainCar-v0', render_mode=\"human\")"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 6,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "The observation space: Box([-1.2 -0.07], [0.6 0.07], (2,), float32)\n",
|
|
|
- "The action space: Discrete(3)\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "# Observation and action space \n",
|
|
|
- "obs_space = env.observation_space\n",
|
|
|
- "action_space = env.action_space\n",
|
|
|
- "print(\"The observation space: {}\".format(obs_space))\n",
|
|
|
- "print(\"The action space: {}\".format(action_space))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 7,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "The initial observation is (array([-0.46305716, 0. ], dtype=float32), {})\n",
|
|
|
- "0\n",
|
|
|
- "The new observation is [-0.46450874 -0.00145157]\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "import matplotlib.pyplot as plt \n",
|
|
|
- "\n",
|
|
|
- "# reset the environment and see the initial observation\n",
|
|
|
- "obs = env.reset()\n",
|
|
|
- "print(\"The initial observation is {}\".format(obs))\n",
|
|
|
- "\n",
|
|
|
- "# Sample a random action from the entire action space\n",
|
|
|
- "random_action = env.action_space.sample()\n",
|
|
|
- "print(random_action)\n",
|
|
|
- "\n",
|
|
|
- "# # Take the action and get the new observation space\n",
|
|
|
- "new_obs, reward, done, _, info = env.step(random_action)\n",
|
|
|
- "env.render() \n",
|
|
|
- "pygame.display.update()\n",
|
|
|
- "print(\"The new observation is {}\".format(new_obs))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 16,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "2\n",
|
|
|
- "The new observation is [-5.2235210e-01 -4.3133463e-04]\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "# Sample a random action from the entire action space\n",
|
|
|
- "random_action = env.action_space.sample()\n",
|
|
|
- "print(random_action)\n",
|
|
|
- "\n",
|
|
|
- "# # Take the action and get the new observation space\n",
|
|
|
- "new_obs, reward, done, _, info = env.step(random_action)\n",
|
|
|
- "print(\"The new observation is {}\".format(new_obs))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 17,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "env.close()\n"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 19,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "<function pygame.event.get>"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 19,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "import pygame\n",
|
|
|
- "\n",
|
|
|
- "pygame.event.get"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 1,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "DependencyNotInstalled",
|
|
|
- "evalue": "Box2D is not installed, run `pip install gymnasium[box2d]`",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/bipedal_walker.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mBox2D\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m from Box2D.b2 import (\n",
|
|
|
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'Box2D'",
|
|
|
- "\nThe above exception was the direct cause of the following exception:\n",
|
|
|
- "\u001b[0;31mDependencyNotInstalled\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-1-c5175ad63d7c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mgymnasium\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"LunarLander-v2\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrender_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"human\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mobservation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/registration.py\u001b[0m in \u001b[0;36mmake\u001b[0;34m(id, max_episode_steps, autoreset, apply_api_compatibility, disable_env_checker, **kwargs)\u001b[0m\n\u001b[1;32m 754\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 755\u001b[0m \u001b[0;31m# Assume it's a string\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 756\u001b[0;31m \u001b[0menv_creator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_env_creator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mentry_point\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 757\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 758\u001b[0m \u001b[0;31m# Determine if to use the rendering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/registration.py\u001b[0m in \u001b[0;36mload_env_creator\u001b[0;34m(name)\u001b[0m\n\u001b[1;32m 543\u001b[0m \"\"\"\n\u001b[1;32m 544\u001b[0m \u001b[0mmod_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\":\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 545\u001b[0;31m \u001b[0mmod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 546\u001b[0m \u001b[0mfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattr_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 547\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/__init__.py\u001b[0m in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0mlevel\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_bootstrap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_gcd_import\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlevel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpackage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_gcd_import\u001b[0;34m(name, package, level)\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load\u001b[0;34m(name, import_)\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_find_and_load_unlocked\u001b[0;34m(name, import_)\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_load_unlocked\u001b[0;34m(spec)\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap_external.py\u001b[0m in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbipedal_walker\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mBipedalWalker\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mBipedalWalkerHardcore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcar_racing\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCarRacing\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgymnasium\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbox2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlunar_lander\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLunarLander\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLunarLanderContinuous\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
|
- "\u001b[0;32m~/opt/anaconda3/lib/python3.8/site-packages/gymnasium/envs/box2d/bipedal_walker.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m )\n\u001b[1;32m 24\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m raise DependencyNotInstalled(\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0;34m\"Box2D is not installed, run `pip install gymnasium[box2d]`\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m ) from e\n",
|
|
|
- "\u001b[0;31mDependencyNotInstalled\u001b[0m: Box2D is not installed, run `pip install gymnasium[box2d]`"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "import gymnasium as gym\n",
|
|
|
- "\n",
|
|
|
- "env = gym.make(\"LunarLander-v2\", render_mode=\"human\")\n",
|
|
|
- "observation, info = env.reset()\n",
|
|
|
- "\n",
|
|
|
- "for _ in range(1000):\n",
|
|
|
- " action = env.action_space.sample() # agent policy that uses the observation and info\n",
|
|
|
- " observation, reward, terminated, truncated, info = env.step(action)\n",
|
|
|
- "\n",
|
|
|
- " if terminated or truncated:\n",
|
|
|
- " observation, info = env.reset()\n",
|
|
|
- "\n",
|
|
|
- "env.close()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 2,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "<contextlib.ExitStack at 0x7fbffd7b04c0>"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 2,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "import gymnasium as gym\n",
|
|
|
- "import math\n",
|
|
|
- "import random\n",
|
|
|
- "import matplotlib\n",
|
|
|
- "import matplotlib.pyplot as plt\n",
|
|
|
- "env = gym.make(\"CartPole-v1\")\n",
|
|
|
- "\n",
|
|
|
- "# set up matplotlib\n",
|
|
|
- "is_ipython = 'inline' in matplotlib.get_backend()\n",
|
|
|
- "if is_ipython:\n",
|
|
|
- " from IPython import display\n",
|
|
|
- "\n",
|
|
|
- "plt.ion()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 52,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "import gymnasium as gym\n",
|
|
|
- "import pygame\n",
|
|
|
- "\n",
|
|
|
- "env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n",
|
|
|
- "observation, info = env.reset(seed=42)\n",
|
|
|
- "#for _ in range(100):\n",
|
|
|
- "# action = env.action_space.sample() # this is where you would insert your policy\n",
|
|
|
- "# observation, reward, terminated, truncated, info = env.step(action)\n",
|
|
|
- "# if terminated or truncated:\n",
|
|
|
- "# break\n",
|
|
|
- "#env.close()\n",
|
|
|
- "\n",
|
|
|
- "#pygame.display.update()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 62,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "(array([ 0.20159529, 1.9464185 , -0.22034578, -2.9908078 ], dtype=float32),\n",
|
|
|
- " 1.0,\n",
|
|
|
- " True,\n",
|
|
|
- " False,\n",
|
|
|
- " {})"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 62,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "action = env.action_space.sample() # this is where you would insert your policy\n",
|
|
|
- "env.step(1)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 63,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "env.close()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 64,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "0.075"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 64,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "0.3/4"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 163,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "m = [\n",
|
|
|
- " [0.775, 0.075, 0.075, 0.075],\n",
|
|
|
- " [0.075, 0.775, 0.075, 0.075],\n",
|
|
|
- " [0.075, 0.075, 0.775, 0.075],\n",
|
|
|
- " [0.075, 0.075, 0.075, 0.775]\n",
|
|
|
- "]"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 164,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "import torch\n",
|
|
|
- "m = torch.tensor(m)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 95,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[1.],\n",
|
|
|
- " [1.],\n",
|
|
|
- " [1.],\n",
|
|
|
- " [1.]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 95,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "m @ torch.ones(4, 1)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 192,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "r = torch.tensor([[2.0, 1.0, -1.0, -2.0]]).T\n",
|
|
|
- "v = torch.tensor([[200.0, 100.0, -100.0, -200.0]]).T\n",
|
|
|
- "gamma = 0.9"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 198,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "tensor([[ 5.4054],\n",
|
|
|
- " [ 2.7027],\n",
|
|
|
- " [-2.7027],\n",
|
|
|
- " [-5.4054]])\n",
|
|
|
- "tensor([[ 5.4054],\n",
|
|
|
- " [ 2.7027],\n",
|
|
|
- " [-2.7027],\n",
|
|
|
- " [-5.4054]])\n",
|
|
|
- "tensor([[ 5.4054],\n",
|
|
|
- " [ 2.7027],\n",
|
|
|
- " [-2.7027],\n",
|
|
|
- " [-5.4054]])\n",
|
|
|
- "tensor([[ 5.4054],\n",
|
|
|
- " [ 2.7027],\n",
|
|
|
- " [-2.7027],\n",
|
|
|
- " [-5.4054]])\n",
|
|
|
- "tensor([[ 5.4054],\n",
|
|
|
- " [ 2.7027],\n",
|
|
|
- " [-2.7027],\n",
|
|
|
- " [-5.4054]])\n",
|
|
|
- "tensor([[ 5.4054],\n",
|
|
|
- " [ 2.7027],\n",
|
|
|
- " [-2.7027],\n",
|
|
|
- " [-5.4054]])\n",
|
|
|
- "tensor([[ 5.4054],\n",
|
|
|
- " [ 2.7027],\n",
|
|
|
- " [-2.7027],\n",
|
|
|
- " [-5.4054]])\n",
|
|
|
- "tensor([[ 5.4054],\n",
|
|
|
- " [ 2.7027],\n",
|
|
|
- " [-2.7027],\n",
|
|
|
- " [-5.4054]])\n",
|
|
|
- "tensor([[ 5.4054],\n",
|
|
|
- " [ 2.7027],\n",
|
|
|
- " [-2.7027],\n",
|
|
|
- " [-5.4054]])\n",
|
|
|
- "tensor([[ 5.4054],\n",
|
|
|
- " [ 2.7027],\n",
|
|
|
- " [-2.7027],\n",
|
|
|
- " [-5.4054]])\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "for i in range(10):\n",
|
|
|
- " print(v)\n",
|
|
|
- " v = gamma * m @ v + r"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 142,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "m = [\n",
|
|
|
- " [0.5, 0.5, 0.0],\n",
|
|
|
- " [0.25, 0.5, 0.25],\n",
|
|
|
- " [0, 0.5, 0.5]\n",
|
|
|
- "]"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 143,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "m = torch.tensor(m)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 160,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "n = torch.tensor([[0.0, 0.9, 0.1]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 2,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "NameError",
|
|
|
- "evalue": "name 'n' is not defined",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-2-17b8a0a14bcd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
|
- "\u001b[0;31mNameError\u001b[0m: name 'n' is not defined"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "for i in range(10):\n",
|
|
|
- " print(n)\n",
|
|
|
- " n = n @ m\n",
|
|
|
- " \n",
|
|
|
- " \n"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 51,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "import torch\n",
|
|
|
- "from torch.autograd import Variable"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 161,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 162,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 170,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "(tensor([[1., 0., 0.]], grad_fn=<AddBackward0>),\n",
|
|
|
- " tensor([[-0.2796, 0.1884, 0.0913]]))"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 170,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "logits = torch.tensor([[1., 1., 1.]], requires_grad=True)\n",
|
|
|
- "y = F.gumbel_softmax(logits, tau=1, hard=True)\n",
|
|
|
- "(y * torch.arange(3)).sum().backward()\n",
|
|
|
- "y, logits.grad"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 148,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[1., 1.]], grad_fn=<AddBackward0>)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 148,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "y_soft = F.gumbel_softmax(logits, tau=1, hard=False)\n",
|
|
|
- "y_soft.max(-1, keepdim=True)[1] - y_soft.detach() + y_soft"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 117,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[1., 0.]], grad_fn=<AddBackward0>)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 117,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "F.gumbel_softmax(logits, tau=1, hard=True)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 81,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "def sample_gumbel(shape, eps=1e-20):\n",
|
|
|
- " U = torch.rand(shape)\n",
|
|
|
- " return -Variable(torch.log(-torch.log(U + eps) + eps))\n",
|
|
|
- "\n",
|
|
|
- "def gumbel_softmax_sample(logits, temperature):\n",
|
|
|
- " y = logits + sample_gumbel(logits.size())\n",
|
|
|
- " return F.softmax(y / temperature, dim=-1)\n",
|
|
|
- "\n",
|
|
|
- "def gumbel_softmax(logits, temperature):\n",
|
|
|
- " \"\"\"\n",
|
|
|
- " ST-gumple-softmax\n",
|
|
|
- " input: [*, n_class]\n",
|
|
|
- " return: flatten --> [*, n_class] an one-hot vector\n",
|
|
|
- " \"\"\"\n",
|
|
|
- " y = gumbel_softmax_sample(logits, temperature)\n",
|
|
|
- " shape = y.size()\n",
|
|
|
- " _, ind = y.max(dim=-1)\n",
|
|
|
- " y_hard = torch.zeros_like(y).view(-1, shape[-1])\n",
|
|
|
- " y_hard.scatter_(1, ind.view(-1, 1), 1)\n",
|
|
|
- " y_hard = y_hard.view(*shape)\n",
|
|
|
- " y_hard = (y_hard - y).detach() + y\n",
|
|
|
- " return y_hard.view(-1,latent_dim*categorical_dim), y, ind"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 53,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "NameError",
|
|
|
- "evalue": "name 'latent_dim' is not defined",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-53-56c036ce50bf>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgumbel_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
|
- "\u001b[0;32m<ipython-input-52-69711d14ce77>\u001b[0m in \u001b[0;36mgumbel_softmax\u001b[0;34m(logits, temperature)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0my_hard\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_hard\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0my_hard\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0my_hard\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0my_hard\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlatent_dim\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mcategorical_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
|
- "\u001b[0;31mNameError\u001b[0m: name 'latent_dim' is not defined"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "gumbel_softmax(torch.randn(4, 3), 0.1)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 55,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "latent_dim = 10\n",
|
|
|
- "categorical_dim = 3\n",
|
|
|
- "x = torch.randn(5, latent_dim, categorical_dim)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 85,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "x.requires_grad=True"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 86,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "y, k, ind = gumbel_softmax(x, 0.1)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 87,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[1, 2, 2, 2, 2, 0, 2, 1, 2, 1],\n",
|
|
|
- " [0, 2, 0, 0, 2, 1, 0, 0, 0, 0],\n",
|
|
|
- " [0, 0, 2, 1, 1, 1, 1, 0, 0, 0],\n",
|
|
|
- " [1, 2, 1, 2, 0, 1, 1, 1, 0, 0],\n",
|
|
|
- " [2, 0, 0, 1, 0, 2, 0, 1, 0, 1]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 87,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "ind"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 89,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "torch.return_types.max(\n",
|
|
|
- "values=tensor([[0.5745, 1.0000, 0.9985, 1.0000, 1.0000, 0.5934, 1.0000, 0.8247, 1.0000,\n",
|
|
|
- " 0.9999],\n",
|
|
|
- " [0.7558, 1.0000, 0.9999, 1.0000, 1.0000, 0.5294, 0.9994, 0.8727, 0.9999,\n",
|
|
|
- " 0.9213],\n",
|
|
|
- " [1.0000, 1.0000, 1.0000, 1.0000, 0.9942, 1.0000, 0.9996, 1.0000, 1.0000,\n",
|
|
|
- " 1.0000],\n",
|
|
|
- " [0.9993, 1.0000, 1.0000, 1.0000, 0.9998, 0.9998, 1.0000, 0.9661, 0.5564,\n",
|
|
|
- " 1.0000],\n",
|
|
|
- " [1.0000, 1.0000, 0.9999, 1.0000, 0.9785, 0.8706, 1.0000, 1.0000, 1.0000,\n",
|
|
|
- " 0.5304]], grad_fn=<MaxBackward0>),\n",
|
|
|
- "indices=tensor([[1, 2, 2, 2, 2, 0, 2, 1, 2, 1],\n",
|
|
|
- " [0, 2, 0, 0, 2, 1, 0, 0, 0, 0],\n",
|
|
|
- " [0, 0, 2, 1, 1, 1, 1, 0, 0, 0],\n",
|
|
|
- " [1, 2, 1, 2, 0, 1, 1, 1, 0, 0],\n",
|
|
|
- " [2, 0, 0, 1, 0, 2, 0, 1, 0, 1]]))"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 89,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "k.max(dim=-1)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 94,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([1, 2, 2, 2, 2, 0, 2, 1, 2, 1])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 94,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "k.max(dim=-1)[1][0]"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 92,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "torch.Size([10])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 92,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "torch.arange(10).shape"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 171,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[1., 0., 0.]], grad_fn=<AddBackward0>)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 171,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "y"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 107,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "NameError",
|
|
|
- "evalue": "name 'nn' is not defined",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-107-e31dcaa6664b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
|
|
- "\u001b[0;31mNameError\u001b[0m: name 'nn' is not defined"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "l = nn.embedding(3, 4)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 172,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "ename": "NameError",
|
|
|
- "evalue": "name 'nn' is not defined",
|
|
|
- "output_type": "error",
|
|
|
- "traceback": [
|
|
|
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
|
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|
|
- "\u001b[0;32m<ipython-input-172-23c43f10cdd2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mclass\u001b[0m \u001b[0mRewardModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membedding\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
|
- "\u001b[0;31mNameError\u001b[0m: name 'nn' is not defined"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "class RewardModel(nn.Module):\n",
|
|
|
- "\n",
|
|
|
- " def __init__(self, model):\n",
|
|
|
- " super().__init__()\n",
|
|
|
- " self.embedding = model\n",
|
|
|
- " self.score = nn.Linear(model.embed_dim, 1, bias=False)\n",
|
|
|
- "\n",
|
|
|
- " def forward(self, x, seq_len=None):\n",
|
|
|
- " # x:表示文本,形状(B, T, C), seq_len:表示文本长度,形状(B)\n",
|
|
|
- " B, T = x.shape\n",
|
|
|
- " emb = self.embedding(x).last_hidden_state # (B, T, C)\n",
|
|
|
- " ind = torch.arange(B, device=x.device)\n",
|
|
|
- " if seq_len == None:\n",
|
|
|
- " seq_len = torch.tensor([T] * B)\n",
|
|
|
- " # 获取最后一个词元的特征\n",
|
|
|
- " pooled_emb = emb[ind, seq_len - 1] # (B, C)\n",
|
|
|
- " score = self.score(pooled_emb) # (B, 1)\n",
|
|
|
- " return score\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "r_model = RewardModel(GPT2Model.from_pretrained('gpt2'))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 175,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "x = torch.randn(1, 4, requires_grad=True)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 176,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(1.7199, grad_fn=<MaxBackward1>)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 176,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "torch.max(x)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 181,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[1]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 181,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "logits = torch.randn(1, 5, requires_grad=True)\n",
|
|
|
- "probs = F.softmax(logits, dim=-1)\n",
|
|
|
- "y = torch.multinomial(probs, num_samples=1)\n",
|
|
|
- "y"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 182,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[1., 0., 0., 0., 0.]], grad_fn=<AddBackward0>)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 182,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "y_one_hot = F.gumbel_softmax(logits, hard=True)\n",
|
|
|
- "gumbel_y = torch.argmax(y_one_hot, dim=-1, keepdim=True)\n",
|
|
|
- "y_one_hot"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 183,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[0]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 183,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "gumbel_y"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 189,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "def log(t, eps = 1e-20):\n",
|
|
|
- " return torch.log(t.clamp(min = eps))\n",
|
|
|
- "\n",
|
|
|
- "def gumbel_noise(t):\n",
|
|
|
- " noise = torch.zeros_like(t).uniform_(0, 1)\n",
|
|
|
- " return -log(-log(noise))\n",
|
|
|
- "\n",
|
|
|
- "def gumbel_sample(t, temperature = 1., dim = -1):\n",
|
|
|
- " return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 190,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([2])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 190,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "gumbel_sample(torch.randn(1, 4, requires_grad=True))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 191,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "B, T, vs = (3, 4, 20)\n",
|
|
|
- "logits = torch.randn(B, T, vs)\n",
|
|
|
- "labels = torch.randint(vs, (B, T))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 197,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "(tensor(3.9027), tensor(3.9027))"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 197,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "l = F.cross_entropy(logits.view(B * T, vs), labels.view(B * T), reduction='none')\n",
|
|
|
- "l.mean(), F.cross_entropy(logits.view(B * T, vs), labels.view(B * T))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 203,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[-4.3894, -4.6075, -2.8143, -3.5261],\n",
|
|
|
- " [-4.3982, -3.5565, -2.3154, -4.5414],\n",
|
|
|
- " [-4.1661, -3.0418, -2.9297, -6.5461]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 203,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "lnP = -F.cross_entropy(logits.transpose(-2, -1), labels, reduction='none')\n"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 221,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "lossi = F.cross_entropy(logits.transpose(-2, -1), labels, reduction='none')"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 225,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(3.9027)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 225,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "lossi.mean(-1).mean(-1)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "tokenizer"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 220,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(3.9027)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 220,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "(-r * lnP).mean()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 211,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([-15.3374, -14.8116, -16.6836])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 211,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "lnP.sum(-1)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 213,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[0, 0, 0, 0, 1, 1, 1, 0, 1, 1],\n",
|
|
|
- " [0, 0, 1, 0, 1, 1, 0, 0, 1, 0],\n",
|
|
|
- " [0, 0, 1, 1, 1, 1, 1, 1, 1, 0],\n",
|
|
|
- " [0, 1, 1, 1, 1, 1, 0, 0, 0, 1],\n",
|
|
|
- " [1, 0, 1, 1, 1, 1, 1, 1, 0, 0],\n",
|
|
|
- " [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],\n",
|
|
|
- " [1, 0, 0, 0, 0, 1, 1, 1, 0, 0],\n",
|
|
|
- " [0, 1, 1, 0, 0, 0, 0, 1, 1, 0],\n",
|
|
|
- " [0, 1, 1, 1, 1, 1, 0, 0, 1, 0],\n",
|
|
|
- " [1, 1, 1, 1, 0, 0, 0, 0, 1, 1],\n",
|
|
|
- " [0, 1, 0, 1, 1, 1, 1, 1, 0, 1],\n",
|
|
|
- " [0, 1, 0, 0, 1, 1, 0, 1, 0, 1],\n",
|
|
|
- " [0, 0, 1, 1, 0, 1, 0, 0, 0, 0],\n",
|
|
|
- " [0, 0, 1, 1, 1, 0, 0, 0, 1, 0],\n",
|
|
|
- " [0, 1, 1, 0, 0, 1, 0, 1, 1, 0],\n",
|
|
|
- " [0, 0, 0, 0, 0, 1, 1, 1, 1, 0],\n",
|
|
|
- " [1, 0, 0, 1, 0, 0, 1, 0, 1, 1],\n",
|
|
|
- " [1, 1, 1, 0, 0, 0, 0, 0, 0, 1],\n",
|
|
|
- " [1, 1, 1, 0, 1, 1, 1, 1, 0, 0],\n",
|
|
|
- " [0, 0, 0, 1, 1, 1, 1, 1, 0, 1]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 213,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "torch.randint(2, (20, 10))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 2,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "import torch"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 23,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "a = torch.tensor([0.5, 0.2, 0.1])\n",
|
|
|
- "r = torch.tensor([1001., 1002., 1003.])"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 28,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "(tensor([500.5000, 200.4000, 100.3000]), tensor(208.2627))"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 28,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "a * r, (a * r).std()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 29,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "(tensor([0.0000, 0.2000, 0.2000]), tensor(0.1155))"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 29,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "a * (r-1001), (a * (r-1001)).std()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 22,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([0., 1., 2.])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 22,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "r-1001"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 34,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(208.1666)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 34,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "(1000 * a).std()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 35,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(0.2082)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 35,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "(1 * a).std()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 38,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([1001.0000, 400.4000, 200.2000])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 38,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "1000 * a + 1002 * a"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 39,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([1.5000, 0.6000, 0.3000])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 39,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "a + 2 * a"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 64,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "torch.Size([2, 20])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 64,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "g = torch.normal(1, 1, (2, 20))\n",
|
|
|
- "g.shape"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 78,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "torch.Size([2, 1])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 78,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "r = torch.normal(1, 1, (2, 1))\n",
|
|
|
- "r.shape"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 79,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[ 0.9881, 0.7762, 1.3812, 0.7958, 1.0094, 0.1070, 0.1709, 0.3961,\n",
|
|
|
- " -0.0118, 0.2604, 0.5242, 0.2400, 0.1989, 0.2737, 1.4879, 0.4091,\n",
|
|
|
- " 0.7199, -0.2781, 0.7418, 1.2079],\n",
|
|
|
- " [ 3.4830, 0.1242, 0.8680, 0.0464, 0.8523, 1.0290, -0.3891, 2.5647,\n",
|
|
|
- " -0.1333, 3.3386, 0.1980, 0.2184, 0.4268, 4.9361, 3.0060, 1.3053,\n",
|
|
|
- " 2.5407, 1.2318, 1.8263, 1.0207]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 79,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "help(g * r)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 86,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- "Help on method norm in module torch._tensor:\n",
|
|
|
- "\n",
|
|
|
- "norm(p: Union[float, str, NoneType] = 'fro', dim=None, keepdim=False, dtype=None) method of torch.Tensor instance\n",
|
|
|
- " See :func:`torch.norm`\n",
|
|
|
- "\n"
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "help(g.norm)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 59,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(1.4053)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 59,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "(g * rr).std()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 88,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([39.9750, 37.7648])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 88,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "(g**2).sum(dim=-1)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 179,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "def norm_(g):\n",
|
|
|
- " k = g / g.norm(dim=-1, keepdim=True)\n",
|
|
|
- " return k"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 1,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "import torch"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 2,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(99.9496)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 2,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "num = 10000\n",
|
|
|
- "grad = torch.normal(0, 1, (num, 100))\n",
|
|
|
- "g = torch.normal(100, 1, (num, 1))\n",
|
|
|
- "(grad * g).std()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 3,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(1.0004)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 3,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "num = 10000\n",
|
|
|
- "grad = torch.normal(0, 1, (num, 100))\n",
|
|
|
- "g = torch.normal(100, 1, (num, 1)) - 100\n",
|
|
|
- "(grad * g).std()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 261,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor(0.0996)"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 261,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "(g * r).mean(-1).std()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": 241,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/plain": [
|
|
|
- "tensor([[100.],\n",
|
|
|
- " [100.],\n",
|
|
|
- " [100.],\n",
|
|
|
- " [100.],\n",
|
|
|
- " [100.],\n",
|
|
|
- " [100.],\n",
|
|
|
- " [100.],\n",
|
|
|
- " [100.],\n",
|
|
|
- " [100.],\n",
|
|
|
- " [100.]])"
|
|
|
- ]
|
|
|
- },
|
|
|
- "execution_count": 241,
|
|
|
- "metadata": {},
|
|
|
- "output_type": "execute_result"
|
|
|
- }
|
|
|
- ],
|
|
|
- "source": [
|
|
|
- "torch.normal(100, 0, (10, 1))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
- }
|
|
|
- ],
|
|
|
- "metadata": {
|
|
|
- "kernelspec": {
|
|
|
- "display_name": "Python 3",
|
|
|
- "language": "python",
|
|
|
- "name": "python3"
|
|
|
- },
|
|
|
- "language_info": {
|
|
|
- "codemirror_mode": {
|
|
|
- "name": "ipython",
|
|
|
- "version": 3
|
|
|
- },
|
|
|
- "file_extension": ".py",
|
|
|
- "mimetype": "text/x-python",
|
|
|
- "name": "python",
|
|
|
- "nbconvert_exporter": "python",
|
|
|
- "pygments_lexer": "ipython3",
|
|
|
- "version": "3.8.5"
|
|
|
- }
|
|
|
- },
|
|
|
- "nbformat": 4,
|
|
|
- "nbformat_minor": 4
|
|
|
-}
|