{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Logistic regression on the Adult income data\n",
    "\n",
    "Fits a two-logit (softmax) logistic-regression model on five numeric columns of `data/adult.data` to predict whether `label` is `>50K`, reports recall / precision / F1, and sets up a pairwise preference model built from the same `Linear` layer.\n",
    "\n",
    "- The metrics are computed on the same rows the model is trained on (there is no held-out split).\n",
    "- `Linear` comes from the project-local `utils` module, which must be importable (e.g. `utils.py` next to this notebook).\n",
    "- `data/adult.data` is read relative to the working directory and must have a header row with the column names used below.\n",
    "- Outputs are not stored; use Restart Kernel + Run All to reproduce them (the torch seed is fixed)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pinned to the torcheval version this notebook was last run with\n",
    "%pip install -q torcheval==0.0.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torcheval.metrics.functional import binary_f1_score, binary_precision\n",
    "from torcheval.metrics.functional.classification import binary_recall\n",
    "\n",
    "from utils import Linear  # project-local module\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration\n",
    "SEED = 1024\n",
    "# Relative path built with os.path.join so it works on both Windows and POSIX\n",
    "DATA_PATH = os.path.join(\"data\", \"adult.data\")\n",
    "FEATURE_COLS = [\"age\", \"education_num\", \"capital_gain\", \"capital_loss\", \"hours_per_week\"]\n",
    "LABEL_COL = \"label\"\n",
    "\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data(path):\n",
    "    \"\"\"\n",
    "    Read the CSV file at `path` with pandas and keep only FEATURE_COLS plus LABEL_COL.\n",
    "\n",
    "    The file needs a header row containing those column names (read_csv uses the\n",
    "    first row as the header by default); the columns are returned in that order.\n",
    "    \"\"\"\n",
    "    data = pd.read_csv(path)\n",
    "    cols = FEATURE_COLS + [LABEL_COL]\n",
    "    return data[cols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = read_data(DATA_PATH)\n",
    "# pd.Categorical infers its categories in sorted order, so the <=50K class gets code 0\n",
    "# and the >50K class gets code 1\n",
    "data[\"label_code\"] = pd.Categorical(data[LABEL_COL]).codes\n",
    "assert data[\"label_code\"].isin([0, 1]).all(), \"expected a binary label with no missing values\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the data\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogitRegression:\n",
    "    \"\"\"\n",
    "    Two-class logistic regression expressed as a pair of scorers.\n",
    "\n",
    "    `neg` and `pos` are each applied to the same batch `x` (in this notebook they are\n",
    "    `Linear(n_features, 1)`); their outputs are concatenated along dim=1 into\n",
    "    [neg_logit, pos_logit], so column 1 corresponds to label_code 1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, neg, pos):\n",
    "        self.pos = pos\n",
    "        self.neg = neg\n",
    "\n",
    "    def __call__(self, x):\n",
    "        # A softmax over these two logits equals sigmoid(pos(x) - neg(x)).\n",
    "        self.out = torch.concat((self.neg(x), self.pos(x)), dim=1)\n",
    "        return self.out\n",
    "\n",
    "    def parameters(self):\n",
    "        # Assumes utils.Linear.parameters() returns a list (the two are joined with +).\n",
    "        return self.neg.parameters() + self.pos.parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the model: one linear scorer per class\n",
    "pos = Linear(len(FEATURE_COLS), 1)\n",
    "neg = Linear(len(FEATURE_COLS), 1)\n",
    "model = LogitRegression(neg, pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare the tensors\n",
    "x = torch.tensor(data[FEATURE_COLS].values).float()\n",
    "# NOTE: F.normalize (p=2, dim=1 by default) rescales each *row* (sample) to unit L2 norm;\n",
    "# it does not standardize each feature column, so large-magnitude columns such as\n",
    "# capital_gain dominate the rows where they are non-zero. Column-wise standardization\n",
    "# is a common alternative.\n",
    "x = F.normalize(x)  # shape [n_samples, n_features]\n",
    "y = torch.tensor(data[\"label_code\"].values).long()  # shape [n_samples]\n",
    "assert x.shape == (len(data), len(FEATURE_COLS)) and y.shape == (len(data),)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the model on a single example.\n",
    "# Keep the batch dimension: x[[1]] has shape [1, n_features] (x[1] would be 1-D, and\n",
    "# torch.concat(..., dim=1) inside the model needs 2-D outputs). The logits are [1, 2].\n",
    "logits = model(x[[1]])  # shape: [1, 2]\n",
    "probs = F.softmax(logits, dim=1)  # shape: [1, 2]\n",
    "pred = torch.where(probs[:, 1] > 0.5, 1, 0)\n",
    "logits, probs, pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss of the model on that single example\n",
    "loss = F.cross_entropy(logits, y[[1]])\n",
    "# Manual equivalent of cross_entropy: mean negative log-probability of the true class\n",
    "-probs[torch.arange(1), y[[1]]].log().mean(), loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Track gradients for the model parameters (needed for backpropagation)\n",
    "for p in model.parameters():\n",
    "    p.requires_grad = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for mini-batch stochastic gradient descent\n",
    "max_steps = 20000\n",
    "batch_size = 3000\n",
    "lossi = []  # loss of every step, plotted below\n",
    "\n",
    "for i in range(max_steps):\n",
    "    # Sample a random mini-batch (indices drawn independently, so repeats are possible)\n",
    "    ix = torch.randint(0, x.shape[0], (batch_size,))\n",
    "    xb = x[ix]\n",
    "    yb = y[ix]\n",
    "\n",
    "    # Forward pass\n",
    "    logits = model(xb)\n",
    "    loss = F.cross_entropy(logits, yb)\n",
    "    # Backward pass\n",
    "    loss.backward()\n",
    "\n",
    "    # Update the parameters\n",
    "    ## Step learning-rate decay after the first half of training\n",
    "    learning_rate = 0.1 if i < max_steps // 2 else 0.01\n",
    "    with torch.no_grad():\n",
    "        for p in model.parameters():\n",
    "            p -= learning_rate * p.grad\n",
    "            p.grad = None\n",
    "\n",
    "    # Logging\n",
    "    if i % 2000 == 0:\n",
    "        print(f'step {i: 6d}/{max_steps}, loss: {loss.item(): .4f}')\n",
    "    lossi.append(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the training loss of every step\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(lossi)\n",
    "ax.set(title=\"Training loss per mini-batch\", xlabel=\"step\", ylabel=\"cross-entropy loss\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on the full data set without tracking gradients\n",
    "with torch.no_grad():\n",
    "    logits = model(x)\n",
    "    probs = F.softmax(logits, dim=1)\n",
    "    pred = torch.where(probs[:, 1] > 0.5, 1, 0)\n",
    "logits.shape, probs.shape, pred.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recall, precision and F1 for the positive class (label_code 1), as (input, target).\n",
    "# NOTE: computed on the same rows used for training; there is no held-out split.\n",
    "binary_recall(pred, y), binary_precision(pred, y), binary_f1_score(pred, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PreferenceModel:\n",
    "    \"\"\"\n",
    "    Pairwise preference model: scores two inputs with the same scorer `pref`.\n",
    "\n",
    "    Returns [pref(x0), pref(x1)] concatenated along dim=1; below, a softmax over\n",
    "    dim=1 is used as the estimated probability of each input being preferred.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pref):\n",
    "        self.pref = pref\n",
    "\n",
    "    def __call__(self, x0, x1):\n",
    "        self.out = torch.concat((self.pref(x0), self.pref(x1)), dim=1)\n",
    "        return self.out\n",
    "\n",
    "    def parameters(self):\n",
    "        return self.pref.parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preference = Linear(len(FEATURE_COLS), 1)\n",
    "p_model = PreferenceModel(preference)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the first two rows\n",
    "x0 = x[[0]]\n",
    "x1 = x[[1]]\n",
    "p_logits = p_model(x0, x1)\n",
    "p_probs = F.softmax(p_logits, dim=1)\n",
    "p_logits, p_probs"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}