{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 安装第三方库\n", "!pip install torcheval" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn.functional as F\n", "from utils import Linear\n", "import pandas as pd\n", "import os\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "# 固定随机数生成种子,使得计算结果可以复现\n", "torch.manual_seed(1024)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def read_data(path):\n", " \"\"\"\n", " 使用pandas读取数据\n", " \"\"\"\n", " data = pd.read_csv(path)\n", " cols = [\"age\", \"education_num\", \"capital_gain\", \"capital_loss\", \"hours_per_week\", \"label\"]\n", " return data[cols]\n", "\n", "\n", "if os.name == \"nt\":\n", " data_path = \".\\\\data\\\\adult.data\"\n", "else:\n", " data_path = \"./data/adult.data\"\n", "data = read_data(data_path)\n", "data[\"label_code\"] = pd.Categorical(data.label).codes" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageeducation_numcapital_gaincapital_losshours_per_weeklabellabel_code
039132174040<=50K0
150130013<=50K0
23890040<=50K0
35370040<=50K0
428130040<=50K0
........................
3255627120038<=50K0
325574090040>50K1
325585890040<=50K0
325592290020<=50K0
3256052915024040>50K1
\n", "

32561 rows × 7 columns

\n", "
" ], "text/plain": [ " age education_num capital_gain capital_loss hours_per_week label \\\n", "0 39 13 2174 0 40 <=50K \n", "1 50 13 0 0 13 <=50K \n", "2 38 9 0 0 40 <=50K \n", "3 53 7 0 0 40 <=50K \n", "4 28 13 0 0 40 <=50K \n", "... ... ... ... ... ... ... \n", "32556 27 12 0 0 38 <=50K \n", "32557 40 9 0 0 40 >50K \n", "32558 58 9 0 0 40 <=50K \n", "32559 22 9 0 0 20 <=50K \n", "32560 52 9 15024 0 40 >50K \n", "\n", " label_code \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 0 \n", "4 0 \n", "... ... \n", "32556 0 \n", "32557 1 \n", "32558 0 \n", "32559 0 \n", "32560 1 \n", "\n", "[32561 rows x 7 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 展示数据\n", "data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class LogitRegression:\n", " \n", " def __init__(self, neg, pos):\n", " '''\n", " 定义逻辑回归模型的结构\n", " 参数\n", " ----\n", " neg :Linear,负面的偏好,模型的形状为(k, 1)\n", " pos :Linear,正面的偏好,模型的形状为(k, 1)\n", " '''\n", " self.pos = pos\n", " self.neg = neg\n", " \n", " def __call__(self, x):\n", " '''\n", " 逻辑回归模型的向前传播\n", " 参数\n", " ----\n", " x :torch.FloatTensor,形状为(n, k),其中n表示批量数据的大小,k表示特征的个数\n", " '''\n", " self.out = torch.concat((self.neg(x), self.pos(x)), dim=1)\n", " return self.out # (n, 2)\n", " \n", " def parameters(self):\n", " return self.neg.parameters() + self.pos.parameters()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# 定义模型\n", "pos = Linear(5, 1)\n", "neg = Linear(5, 1)\n", "model = LogitRegression(neg, pos)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# 准备数据\n", "x = torch.tensor(data[['age', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']].values).float()\n", "x = F.normalize(x) # (32561, 5)\n", "y = torch.tensor(data['label_code']).long() # (32561)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[ 1.2665, -1.7305]]), tensor([[0.9524, 0.0476]]), tensor([0]))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 使用模型\n", "## 注意,模型输入数据的形状一定要是(n, 5)\n", "logits = model(x[[1]]) # (1, 2)\n", "probs = F.softmax(logits, dim=1) # (1, 2)\n", "pred = torch.where(probs[:, 1] > 0.5, 1, 0)\n", "logits, probs, pred" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.0487), tensor(0.0487))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 计算模型在单点的损失\n", "loss = F.cross_entropy(logits, y[[1]])\n", "# cross_entropy的具体实现过程\n", "-probs[torch.arange(1), y[[1]]].log().mean(), loss" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# 对于模型参数,需要记录它们的梯度(为反向传播做准备)\n", "for p in model.parameters():\n", " p.requires_grad = True" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "step 0/20000, loss: 0.6580\n", "step 2000/20000, loss: 0.5092\n", "step 4000/20000, loss: 0.5066\n", "step 6000/20000, loss: 0.5037\n", "step 8000/20000, loss: 0.4958\n", "step 10000/20000, loss: 0.5046\n", "step 12000/20000, loss: 0.5015\n", "step 14000/20000, loss: 0.4952\n", "step 16000/20000, loss: 0.5086\n", "step 18000/20000, loss: 0.5086\n" ] } ], "source": [ "# 标准随机梯度下降法的超参数\n", "max_steps = 20000\n", "batch_size = 3000\n", "lossi = []\n", "\n", "for i in range(max_steps):\n", " # 构造批次训练数据\n", " ix = torch.randint(0, x.shape[0], (batch_size,))\n", " xb = x[ix]\n", " yb = y[ix]\n", " # 向前传播\n", " logits = model(xb)\n", " loss = F.cross_entropy(logits, yb)\n", " # 反向传播\n", " loss.backward()\n", " # 更新模型参数\n", " ## 学习速率衰减\n", " learning_rate = 0.1 if i < 10000 else 0.01\n", " with torch.no_grad():\n", " for p in model.parameters():\n", " p -= learning_rate * p.grad\n", " p.grad = None\n", " \n", " # 统计数据\n", " if i % 2000 == 0:\n", " print(f'step {i: 6d}/{max_steps}, loss: {loss.item(): .4f}')\n", " lossi.append(loss.item())" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAAsTAAALEwEAmpwYAAAsqklEQVR4nO3dd3wUdfoH8M9DQu+9Q+gIIi0gAlJUuid24c6CDVHxTj31UNGznlhQf5yoZwErRVARBaQpVVpC74QQQkILCU1a2vf3x8xuZndndmeT3ewyfN6vV17ZnZ2deXZ25plvm1lRSoGIiJyrRKQDICKi8GKiJyJyOCZ6IiKHY6InInI4JnoiIoeLjXQA3mrUqKHi4uIiHQYR0UUlMTHxmFKqptlrUZfo4+LikJCQEOkwiIguKiKy3+o1Nt0QETkcEz0RkcMx0RMRORwTPRGRwzHRExE5HBM9EZHDMdETETmcoxL9/G2HcfT0+UiHQUQUVRyT6M/n5OGhrxNx52drIh0KEVFUcUyiz9d/QOVA1rkIR0JEFF0ck+iJiMic4xK9An8akYjIyDGJXiCRDoGIKCo5JtETEZE5xyV6xZYbIiIPjkn0wpYbIiJTjkn0LMkTEZlzTKJ3YcmeiMiT4xI9S/ZERJ4ck+hZkiciMueYRE9EROYcl+jZckNE5MlxiT47Nz/SIRARRRXHJPrUrLORDoGIKCo5JtFztA0RkTnnJHq2zhMRmXJMos9n0zwRkSnHJPpq5UtFOgQioqjkmERfghdMERGZckyi5++OEBGZc06iJyIiU7YSvYgMFJFdIpIkImMs5rldRLaLyDYRmWKYniciG/W/2aEKnIiI7IkNNIOIxACYCKAfgDQA60RktlJqu2GeFgCeBdBDKXVcRGoZFnFOKdUhtGETEZFddkr0XQEkKaWSlVLZAKYBGOo1z4MAJiqljgOAUupoaMMkIqLCspPo6wM4YHiepk8zagmgpYisFJHVIjLQ8FoZEUnQp99otgIRGanPk5CRkRFM/EREFEDAppsgltMCQB8ADQAsE5F2SqkTABorpdJFpCmA30Rki1Jqr/HNSqlPAHwCAPHx8bzElYgohOyU6NMBNDQ8b6BPM0oDMFsplaOU2gdgN7TED6VUuv4/GcASAB2LGDMREQXBTqJfB6CFiDQRkVIAhgHwHj0zC1ppHiJSA1pTTrKIVBWR0obpPQBsBxERFZuATTdKqVwRGQ1gPoAYAJOUUttE5BUACUqp2fpr/UVkO4A8AE8rpTJFpDuA/4lIPrSTyjjjaB0iIgo/W230Sqm5AOZ6TXvR8FgBeFL/M87zB4B2RQ8zMOGlsUREpnhlLBGRwzHRExE5HBM9EZHDMdETETmcYxJ9jQr84REiIjOOSfQiHHVDRGTGMYmeiIjMMdETETkcEz0RkcMx0RMRORwTPRGRwzHRExE5HBM9EZHDMdETETkcEz0RkcMx0RMRORwTPRGRwzHRExE5HBM9EZHDMdETETkcEz0RkcMx0RMRORwTPRGRwzHRExE5HBM9EZHDMdETETkcEz0RkcMx0RMRORwTPRGRwzHRExE5HBM9EZHDOTLRn7mQG+kQiIiihiMT/VMzNkU6BCKiqOHIRL8+9XikQyAiihqOTPRERFTAkYleqUhHQEQUPZyZ6CMdABFRFHFmomemJyJys5XoRWSgiOwSkSQRGWMxz+0isl1EtonIFMP0e0Rkj/53T6gC94+ZnojIJTbQDCISA2AigH4A0gCsE5HZSqnthnlaAHgWQA+l1HERqaVPrwbg3wDioWXfRP29HBZDRFRM7JTouwJIUkolK6WyAUwDMNRrngcBTHQlcKXUUX36AAALlVJZ+msLAQwMTejW2HRDRFTATqKvD+CA4XmaPs2oJYCWIrJSRFaLyMAg3gsRGSkiCSKSkJGRYT96IiIKKFSdsbEAWgDoA2A4gE9FpIrdNyulPlFKxSul4mvWrFnkYFigJyIqYCfRpwNoaHjeQJ9mlAZgtlIqRym1D8BuaInfzntDLutMdrhXQUR00bCT6NcBaCEiTUSkFIBhAGZ7zTMLWmkeIlIDWlNOMoD5APqLSFURqQqgvz6NiIiKScBRN0qpXBEZDS1BxwCYpJTaJiKvAEhQSs1GQULfDiAPwNNKqUwAEJFXoZ0sAOAVpVRWOD4IERGZC5joAUApNRfAXK9pLxoeKwBP6n/e750EYFLRwiQiosJy5JWxRERUgImeiMjhmOiJiByOiZ6IyOGY6ImIHI6JnojI4ZjoiYgcjomeiMjhHJvoT57NiXQIRERRwbGJ/tiZC5EOgYgoKjg20RMRkcaxiV4iHQARUZRwbqIXpnoiIsDBiZ6IiDRM9EREDueoRF+lXEn34zmbD0YwEiKi6OGoRJ+XV/Cz4CuTMiMYCRFR9HBUoq9SvmTgmYiILjGOSvR1K5d1P+agGyIijaMS/VP9W0U6BCKiqOOoRF861lEfh4goJBybGdl0Q0SkcW6i500QiIgAOCzRG0vxK5KORS4QIqIo4qhET0REvhyV6NlcQ0Tky1GJnoiIfDHRExE5nKMSPYdUEhH5clSiJyIiX0z0REQO56hEf1ndSpEOgYgo6jgq0ceUYCM9EZE3RyV6b6fO50Q6BCKiiHN0or//i3WRDoGIKOIcnejXpRyPdAi25OblQykVeEYiokJwdKK/GOTm5aP58/Pw+pwdkQ6FiByKiT7CcvO1kvzXq/dHOBIicipbiV5EBorILhFJEpExJq+PEJEMEdmo/z1geC3PMH12KIO3Y2ZiWnGv0m1L2knM2pBua1423BBRuARM9CISA2AigEEA2gAYLiJtTGadrpTqoP99Zph+zjD9htCEbd9TMzaZTldKIScvP6zr/ssHK/D49I32ZmamjxoPfZ2AZ2aa7zfRpv97S/H+ot2RDoOinJ0SfVcASUqpZKVUNoBpAIaGN6zwe/GnbWjx/LxIh8H780Sh+duO4LuEyNUEg7H7yJ94f9GeSIdBUc5Ooq8P4IDheZo+zdstIrJZRGaKSEPD9DIikiAiq0XkRrMViMhIfZ6EjIwM28EXRbS1iSsW6YlCbtXeTGTnhrfmHqzcvPxiv8YnVJ2xPwOIU0pdAWAhgC8NrzVWSsUD+CuA90WkmfeblVKfKKXilVLxNWvWDFFI5pIz/sTxM9lhXUcw+GMpROGxNf0khn+6Gm/Mi64Rbc/M3IwrXlpQrOu0k+jTARhL6A30aW5KqUyl1AX96WcAOhteS9f/JwNYAqBjEeItlANZZ92Prxm/FAPeX1bcIQTEYfTkBDsPn8LQiStx5kJupENBpl6gSzr6Z4Qj8fSDzQEaoWQn0a8D0EJEmohIKQDDAHiMnhGRuoanNwDYoU+vKiKl9cc1APQAsD0UgQfDmOgB4OjpCxZzhk/acc8YElKycPp8jruNPjdfIW7MHKxPLb6LvE6czcaRU+eLbX0EzNqQjoSUrLAs+9T5HKSfOBeWZdv1xtyd2HTgBNaG6TM6SXFeJBkw0SulcgGMBjAfWgL/Tim1TUReERHXKJq/i8g2EdkE4O8ARujTLwOQoE//HcA4pVSxJ/rkY2eKe5UeJv6ehJ5v/o7le7T+h5PncnDrx6vwyLfrfeZduqt4+igAoOvri3HlfxYX2/oIeHz6Rtz68aqwLPv6CSvQY9xvYVn2xcgqkZ7PyQtbASc/XyHX5mi+fIVi6z+w1UavlJqrlGqplGqmlHpdn/aiUmq2/vhZpVRbpVR7pVRfpdROffofSql2+vR2SqnPw/dRNHUqlfGZNnbWVqQdP4u3ft1p+b5NB064EzGgJeNgZOfm49kfNuPoad8d6O35uwAAOw+dds8LADsOnQpqHaGWHebhpVS8Ur1qrpeCxP1ZePnnbX7nEa+hbSMmrw1bAefRKevR3OZovud/3IKWY4tn5J/jrozNtziLj56yAR8u2Wv5vqETV+Kuz9ciP19hxZ5jaP/yAizZddT2ehdsP4ypaw9g9JQNtt+jlG/bvFn0qZln0e0/i3EwTNXysbO2FPq9v249hAXbDocwGiL7bvloFSavTEHcmDk4edazcGbVMLI62X6zUm5eflCl7nlb7R8L09YdCDxTiDgw0ZtPv2Dzy+r82kLc+fkaAEDift/2cqUUnv9xCzYdOOE1Xfu/dp/1TuQaQul37LzJiWrqulQcPnUePwboxEk7fhb/t2hP0G1/36xOBQA8M3MTpqxJtfWerDPZOHMhF6O+WY+RXycGtT4zv+88ihV7jhV5OYXx08Z0/LA+dOPm8/MVHp2yHnsz7HUCXsjNw8zEtEvmxnbnsvOwfE8G8vMVEveHri1/15HTptOLMq7tlo/+KLZSdzg5LtFbHSxWzSRxY+Z4XAV53KtUkJzxp0dzzOkLufh2TSru/GyNx3yFufDJLNJgD/U/L+Ti8EktvlHfJOK9RbuxNyNwn4RZbeW7hDQ896N16X7TgRNIOqodTJ1eXYjr3l0aZLTW7v1infsEGyylFJbvySh0ovzHtI148rvQXQn79oJdmLP5EK4db2/7vL9oD56asQnzQ1Qz+nBJUkiWUxT3Tl6Hj5ea16DHztqKuz5fi+d+3IJbPlqFZbuLr18qWJvSTrofr07OxDt6M+zFxnGJ/qpm1YN+j7+rIK8ZvxRdXy9ozwtFoct1TsgzqX54Lz9xfxYW7zhiuazrJyxHtze0+M5l5+nLCBzkiMnB36t/6MSVuO7dgqGph0769kf8uvUw9md6nmhW7c3Eqr2ZOHr6fJGT2fR1qT7Ln7r2AO76fC1+2niwSMsOlc1pJ4KaP0MfBXbqXGiGJH672rpW9vOmg9h28KTl6wBw9PR5vDR7m+1OxQNZZ92fwWjcvII+sY+X7sXC7dp+nKTXdFwjzA6d9GySPJ+T5y68BCM3Px+fLkvG+RztOJiRcMA93cya5Ew8M3OTreMlL19h2Cer8cHv1idRs+M5Wjgu0TetWSFky9ro1TwDAO8t1O8r4lWCN37J62wOLTPr8PW+QvaWj1Zh9xHrJoCUzIIOOGOnU8qxM/h6VYqtOEJp1DeJ6PfeMpzLzsPHS/cibswcDP90NYZ/uhq3fPQHHvo6EamZ1p2Gz/5gXqM4fiYbSin86/stuOnDPzxeO6APXQ3X0MJtB08GdSW1RV6xZFYZTD9xzueE5u3YnwXJddKKfT6vn7mQi31eI84em7oBQyas8LvcF2ZtxRd/pOB3fQRYyrEz+Me0DcjO1X43YddhzyaSq9/6HV1eX4QHvkzAUovS+bh5O/HgVwmmr7nybH6+glIKj3673l14Ccb3iel4fe4OTNST8dwtWqFiZVKmTw0cAO74ZDW+S0izVXgz3pzwtMVVrcYTfNrxs7ZrmPn60OrJK32/w1BxXKLv1aJGyJa13KTN+Is/UtyPs3Pz3aUjY6K36pxcmZSJI6fOeyTkUCUn43Cxfu8tQ593luCFn7bZ7lD+deuhkMQBaNvlzV93epToAOBAlvZZX5+7HfO2HPIpyQHA1LW+pdElu46i46sL3d9HViGvbM788wKGTFjuc01DIEMmrMALs7YCAI6eOu8umVpJO1G40S/Gk3yPcb+h99tLsMhiXcf+vID41xa5n7/yS8Go5fQT53D6fA7u+nwN+r6zJOg4vEumfd5Zgp82HsT9X67D/5YlY8D7y/DQ175Je5Gfmqc/nyxPRn6+QtPn5uL+LxOweKfnPvvKz9tx9Vu/IS9fYWZiGvL0E8KczZ77rOsirdPnfWtGK5J8j2XXYWj8tEopvLtgl8/wy38abo746fLACbnnm7/b7mzN0UsGb8y1HhVYVI5L9M1CWKIP5LU52zFkwgpMX5fq0UbvPZzLZenuDPzlvyvwm2FHTj/umez8FQJEtGpu3Jg5PiWnpKN/mpYMdxwqKH2dPJuDD5ckmZY0Rn3jO6bfJSEly1ZH6XcJBTu2v2Scr4CHv12PWz+yN57c1Sm+IfWErfmt/LghHdsOnsKkFSm25v90WbLPtGGfrMaDXyWYNmucy87Dgayz7hMaAHdyyvdTrXcnHJNZHjApBefm5Qcs7X+fmIb1hdxerji8R5ss33PMffKev61wSf18Tp7PfpqccQa/bNGStvHYOHk2B//+aSsmrdyHA1nn8NWqFDw1YxOmrE3FrI3peHSK5z7rOlGeOpdj614yrjimrzuA33Zqn2dT2klM+C0pJMMvE/cfx/+W7g14ZW5x9MHHhn8Vxatq+VJhWW7WmWyUKel5XvxqlVad/9f3W/DeHe3d0z9ZlmxZUj96+oLHrZO9zwkK2nDK2BhBvSplfd7vSnbT1qaid8uC+wJNXZtqObT04Ilz6D7uN1QsE4vT53PRrn5ly88JaG3Gufn5OJB1Dj9vOujTbOE9jA0Adh85jWdmbnY/93eguUrE6SfO4ZvV+3Fnt8Yer6ccO4MRk9fiu1FXoVbFgusiAt347deth9GsZgUMvLyONr9SeG/RHgzv2hB1K5f1OOkBwJu/7sTVza1rgK/P3YEbO3rev2+/11j1fcfOYHPaCQztUB893/zNfdm9S7Pn5gLQDvon+rUwXU+g+x1N/D0Jj/ZtDkBrolm4/QhWJWf6fY+3jQdO4JTNa0OW66XfN3/diSFX1LWc7/DJ86hT2fe6FX9av/Cr+7Gx2fHxab7Dkt+avxPfGkaBHTmlNVW9PHsbqpkc566RdT9sSMfsTYH7a0QEUMo9ACFl3BDLY8iDUpi1IR2PT9+IrS8PQLmSMRi/cBfaN6jiMdvq5EzMTEzDG/N2ImXcEMvFuQttYbztleMSfbh0enWhx3Pv78T7YPWuVlo5YZI0e739OwDgg7963hZowuI9GH9bB9Pl/GKxPpGCk4OrSuvdxuqty+uL/L7e/hXfGzJdyPEs/S2xeYXv2Flb0aFhFY9pk1buQ0rmWczdfAgjejRxb9lzeieblS3pJzHqm0T3QTV66gbM2XwIK5OO4fuHu+N7fQil6+T60ZK9+MjPtRUA0Ef/LgBtu7iaNRS0tlpX08jQDvV9krzR1LWpmLvF/z6hoDUvXeM1Wuft+btwTetaqF6hlEcTTTBunLjS7+sLtx/Rrtju3MBdkg90AVa3Nxb7TWCBGGsMZhWe3DzPia5RPLn5yvQ2JsZ9LtdkgXFj5mBwuzru595NVFPWpOKyuhULYrKohU34raBDNv34OWSdycbE3333ozRDbf30+Rx8uyYVI69u6jOfq/M4Ozcfb8zbgWcHXWa63qJwXNNNpNj+gREvPtVPw77lffHV+Zx8e2PxDTL/vIBnf9jsMe21KPt9Wu8E6V2ociX4/y0taEqx0+HtOtleyM3zaK4KpuB0Jrvg5GIcWfLE9I1oZ7gDoZ0fsfHufN+QehxT16Z6NN3cPWmtaSf91vSTQXfyulhd8LNqbybixszBhtTjePCrBNMf6cn80/99oeLGzPH7+l/+u8I9JDcauDpozYydtcWj2TVQwQLQaplWo3qMXvtlB8bN2xnwR2I+t9H+Xxgs0RfSKZMOn1AI1DzhSih2b29sp+Mo0rz7DFyjnX7RS/Rmn+G2j1dh68sDUKF0rGkbp/fYbFcTCgB8tmIfPjMZpZKaeRaNqpezFbN3Deq8jaTgzTV6aHhX7eaw2w+dxLaD5td7PD1zs+l0O8zaiJ/7cQsqlon1iAOAz4VjfxbxLpRb0k96DMkNRiR+o8F4VNkpTCVnnLG1jU5f0E7extqAC9voL0EpAW7A9sR0rdQ1Z8shVC3CrQtCze6QUju2pGsjmRL2H/dbYvx29X68MW8nmtfy7YC/e9Ja92OBWF4xbfTTxnQ8dq15O3qoGU9uU9dqndjf+Bn/HqyXfi5o4hk8YbnP61PWpGJ410Y+070vHOv99pKQxRSs4r53j/cgimOnA4/uMrsxoRl/NYlkm1dQFwWbbqJMMKMZQpkYiqqwbcdA4Uurb+gjQLxLrF8V8vqBovysY7A1J6vL9YtTNF+RCgR3T5pQyMtXeNNw48Nehv6ZcDKW8sP106Is0VPEmV1VWRQv/uR5N0NXDSGQdxbsRm2Tu5/aMWFxcL/bOvB931J2cYv0veuj0R97gxvNdLFgiZ7IoCht4URFlZMXngZ7JnoiIodjoicicjgmeiIih2OiJyJyOCZ6IiKHY6InInI4JnoiIodjoicicjgmeiIih2OiJyJyOCZ6IiKHY6InInI4Ryb6quVKRjoEIqKo4chE/+MjPSIdAhFR1HBkoo+rUR4f/q1TpMMgIooKjkz0AHDdZbUjHQIRUVRwbKIvFevYj0ZEFBRmQyIih3N0om9Rq0KkQyAiijhHJ/ouTapFOgQioohzdKJXKjw/tEtEdDFxeKKPdARERJHn6ETfu2XNSIdARBRxthK9iAwUkV0ikiQiY0xeHyEiGSKyUf97wPDaPSKyR/+7J5TBBzKoXd3iXB0RUVQKmOhFJAbARACDALQBMFxE2pjMOl0p1UH/+0x/bzUA/wZwJYCuAP4tIlVDFr0NT/ZrWZyrIyKKOnZK9F0BJCmlkpVS2QCmARhqc/kDACxUSmUppY4DWAhgYOFCLZzL61cqztUREUUdO4m+PoADhudp+jRvt4jIZhGZKSINg3mviIwUkQQRScjIyLAZuj3xcRxiSUSXtlB1xv4MIE4pdQW0UvuXwbxZKfWJUipeKRVfs2ZoO1ArlSmJyfd2CekyiYguJnYSfTqAhobnDfRpbkqpTKXUBf3pZwA6230vERGFl51Evw5ACxFpIiKlAAwDMNs4g4gYh7fcAGCH/ng+gP4iUlXvhO2vTytWVzWtXtyrJCKKGrGBZlBK5YrIaGgJOgbAJKXUNhF5BUCCUmo2gL+LyA0AcgFkARihvzdLRF6FdrIAgFeUUllh+Bx+lSkZU9yrJCKKGrba6JVSc5VSLZVSzZRSr+vTXtSTPJRSzyql2iql2iul+iqldhreO0kp1Vz/mxyejxGcNnUrYevLAyIdhl+j+zaPdAhE5MeQKy6e63QcfWWsmUf6NMOMUVehQmn/lZmxQy5DL5tX1v53eMdQhAYAaFC1LACgZ4saIVsmWatfpWykQ6CLVMeGVSIdgm2XTKJ/4fo2+OuVjfDMwNYo75Xkt5mU7h+4uqmt5b564+UYUogrcKuXL4V1z1/nMW103+aoF0Ti6d7s4ul7qFQmYCthoYzoHlek939RxBFZPZsX7wm5dJT8oM6jfZuFdflXXwQFnbz8i+dmWtGx1xSD+3s2wX9uaucxbfboHnhusG/i/+GR7gAA8VrG+Nva+yy3pZ973o/oHmfZEXxjx/qoWbG0x7QWtc2X1dXidstf338lalfSllG5bEnLOFz+O7wj+rSKzP1/qpYvFZbltqlbtAvixPtLBvD0gFa23399GKrv7epXxl3dGmPXawNRzWu7dSnidSHDuzYMPFOYNK1R3va8Jcy+mDBpVtN+XEbNL6Lfu7hkEr2ZKxpUwcheWsnEWIWvWUFLnsbzdcq4IbilcwOfZYifHfKlG9rik7s7m7422KQWICJoXaciAKBquVK4Pb4BRnSPw7Au5gdnTImCdc8YdZX78aDL61jGFMzhUz1MydnlsWt8+yGaBJEMAEDBt1QVW0IwoG3hfzP4UYv+kfYmVfUbO2rX/zWuXq7Q6/P2/cPd8eqNl6N0bAzKWgwkiG9c1eck4M+OVwZi3M3t8FAv85L4u7f7FmKMunqdYDo0DP5OJte1sf+dFFeeLx1bAjd38j2u7bj2Ivpd6ks60RvNfPgqvHXLFfjs7ng0rGb/oLXaISeP6KK/7jlDy9oVMP/xXujc2PdAubJJNYwd0gYzRl2FVnUq4q1b2+OlG9ripo5mFyJr7ojXTgLGE9UHf+3kJ96CeFLGDbGc7+kBrfDzYz3x8Z3WyyqM9g2r4M1b2uHDv3XCP/v7lpx/Gt2jUMsddHkdxOvb9JsHrsT42zsUJUxbFjzRC2VKxuDd29tj6oPdLOcL9qTj7/eOOzWqAgD4z83tfGqo3p4d1Nr9uGypGAzr2ghxFifSQMmuZ4sauKF9PQDAc4Nbo1+b2pg20vozm7Gbu81+Ga5htbLoa1IbDdTHsuGFfn6nj72+DR7u7Xny864lLnmqj9912PHi9W0wspdnc3CXuOK77RcTva5u5bK4vUtDj1KH2Y55betaAIDHr2uhv6+M6fKs2trLlIxBK73UbpQybghqVyqDUrElfKrn/moNT/Rrid2vDfJpfrLSsKr1gTHbkGRv7lQf9aqUxcDL6wZsHoktEfgQNs5xR5dGPjWaFf/qi4VP9EKlMgVNUD2aVzc9uM1ULBOLe3s0AaAligqlYz3aeb9/uLvpshpULaf/958wvri3i8cPHPRsXgMta2vf482dGqBelbJY/M/e6KgnYqP/3RXvMy2uejn0MynhPje4tcdz76/+H9e1xMIneqFl7YoYeHkd7HtjsGXMD/VuhmcGtsJMQ22vsJQCuunNkFc11bZrB72GY2zq8tsXY7GbXHdZLc/ZBD41mWY1K2DyvV1xcyet0OM67iZYDIQYO+QypIwb4tNkuPb5azHlgStRtXwppIwbgru6NUYJw/778+ie7qZb13PjyfGDv3quz24z3309m+C5wZe5n68ccw2+uu9KdG9WHTUqlPI49sIhPD1kDvbhnZ1w8mwOalQojbu6NUb1CqWR79Ups/ifvdGsplYqiQlxHbRToypYn3oCz+s7jYigVKy2jpdvaItXftkOY95d9nRfzN16CJNX7kPP5jXQv21tfLlqv+myr2hQxbSU/9PoHvjf0r14Z8Fu0/f956Z2eOb7zT7T29arhG0HTwEALqtbCfWqlPXY2Y1cCdfo2we0EuOG1OO46cM/TN9nNOSKuhhyRUH8t3RqgOV7junPPBt54htXxcyHtQP6q/u6onWdiuj6n8Xu10d0j8MXf6QA0Jot+rSqhfcWFnx+syajZjUr4OHezTDy60TLGKuXL4XMM9mYMao7zlzIxcLtRzDr0R64ceJKbblei61eoTTSjp9zP48pIWhRu6CgYCwE/Gtga3y6PBljBrZ2NzM90qdow3Q3v9QfL8/ejvt6xqFC6Vj0bV0TdStrJ8UyJWPc+8vb83ehRoXSWPBELzz/4xbM23oYvVrWxLLdBfeuEssyveDpAa2Qn68wfuFuCMSn0/k+/ST+7u0d8K5eW7uQm4fSsZ4nhFmP9kBq1ll37cNbrYplUKuieeEMANo1qGz6vEq5kjhxNgfdm9XA+Nva4/jZbABaM9/b83dZLs+KqyYyxU9NMJSY6P247rJaWLrb8yZrpWNjUKuStnNVr1Da7G3uJA9oVeZXh7ZFlXKl8NjUDYWOZdOL/bHn6Gks3HEE61NPoINJyfGe7nG4x2sUSqPq5TCqdzOMMlRP29SthO2HTtled8kYz1rGv//SBi//vN393CzpAVq7tSvRl4otUeidumMj8yru1S1qoFUdrbbRzaTT2+oc+9g1zT1GVZkNo728fsEB/8/+LX0W2LSGeUdc/7Z18PPonvjLBysAADW89pGaFUsj80w2YkoI4mqU9zmxejfbfHp3ZyzYdgRjZ201/zAGD/dphof7BDcaZtGTvdDE67MsfboPer+9BIB2r6jxhvZ7V5L3NvXBbmhaszyqlS+F/w7viB2HTqNdg8rYfvAUJq/chxmJaZbfh4iWMHccOoXxC30LE1ZNjN5JHtBqGR0shj2a1basvH7T5fjgtyT3c9cJWADTvjork0bEIznjjPv51/d3xa7Dpy3nH9jWun+tKJjo/bizW2O88NO2oN5jVnW966o4HMg6CwDoZJG0AqlcriTi46qhfcMq6N6sRpFGX/z8WE/kB/k7i66SY5e4qri3RxOPRG+ld8uaWJmUiZPncoLqBPbn8vqV8Eif5tiUdgJP9muJ0rExSBh7nU9CBYCBXp3SQzvUw5JdGbijS0PTUUqf3h2P3Uc8D8KbO9XHlfpJ5LG+zfHAVwn4+M5O6NOqls/7XYylwoSxnkNov7yvK5bvOebTkZr0+iBMWZuKYV0aeUyvVbEM7uzW2G+iv7Z1LSzeedTydW9rnrsWK/Ycw+B2dVG2VEGyXPRkb1QuW9JnNJgdVxmG+sbGlHBvgzb1Krk7qgVabffa8Uv9Lst4QigVE7rW5RkP2W/C+tuVjfG3Kxu7n7t+f9rsZFWnUhkcPnXe/Xz7KwPQ5kXtTi/XtK6NawytcVe3qImrW5g3R+55fVDIWwBcmOj9EBEkjr3Obxu5tzl/v9p0esNq5TD/8V5oWsihXC4lY0oU+ScSY0oIYrxSb7Djs/9vWAckHf3TXYUFgNs6N8CMxDTDekrgpRva4InpmyzK/Pa9duPlALSTL+A5asksyQNaie+nR3vghZ+2om29yujcuBpu6mhdGuvXprZpu7nLdW1q++3AtqN2pTK41aREGBtTAndfFWf5vmkju2Fr+knT1z66szPO5eQFFYNZqTRcwwXdpWHRarvjb2uPf87Y5H7d39E17hb/Hc52TB/ZDYdOnkdsCE4aZs1Py//VF0oBLcfOAwCUK1W4tFoyhCc1b0z0AVg1z1jxN2LHrBP2/4Z1wKlzOUHHZUeP5oEvqEoYex3Kloyx3ZnrMrRDwUig6etS3Y/HDrkMr83R7mlXvUIpnDCcCIrCleCD1b5hFcwe3TMkMURSt6bVTZunAK25x99InWjhSpLGWgRQUEo2q2SGooB7pY2bGj7UqykOHD9r+fqkEV3wxR8pqGhSYw9ngg4VJvoIMybMUFr/Qj+ULx34Zm5WpeHCENGuKB7RPQ5Ld2egb6ta+HFDmt/3vHt7e6RkWh9gFBnv3NYeszcdDMmyvPO3q9mqcfVy2J951k8nbfF51mKQgEt8XLWL+keMmOhDwLUjF+PFfAEFczFNUXkfqLExJdwXkwQ6iAt7sUq0e+e29pYXO10Mbu3cwLSJqTDu6NIQP286iL9eqfU/dGtaHZPv7YJT53Lwj2kb3ceNayTK37o1RmJKsd/ktsheHdoWbepF50+XMtGHgCuVucbYX2pu6FAPifuP45mBvmOKXbdvGN61kc9rThaqJOkEtSuVwcIne3tM69uqFn7ZrNUYXIm+crmS7j6QhIsw0d9l6GP5+7UtIna7ETNM9CFQooRgxb/6hrQZ5GJSpmQM3rz1CtPX6lUpW+QOzEgoU1Jrd60YZN8FBS9UTTd/v7ZFSJYTCk/2axnpEDxwLw4Rswt+6OI1+PK6eHbQOdx1VeE6gSmwxtW0EWidTG4HEqyezWtEXXKNJkz0ZEvHRlUwtEM9/COKSk3hVKKE4KHe4b0V76WuXYPKWPp0HzTyM1LNTml/2dN9CzX2/1LCRE+2lIwpgf8bFrofWCECgMbVza8ridOn22kObRTCO4c6FRM9RZ2yJWMs71VCl4bHrmmOTo2r8pfWQoSJnqLOjlcHRjoEirDYEFwBTgWi/5IuIiIqEiZ6IiKHY6InInI4JnoiIodjoicicjiOuiEqBu/d0R61K1n/hB1RODHRExUDfz94QhRubLohInI4JnoiIodjoicicjgmeiIih2OiJyJyOCZ6IiKHY6InInI4JnoiIocTpVSkY/AgIhkA9hdhETUAHAtROKHEuILDuILDuILjxLgaK6VMb+IfdYm+qEQkQSkVH+k4vDGu4DCu4DCu4FxqcbHphojI4ZjoiYgczomJ/pNIB2CBcQWHcQWHcQXnkorLcW30RETkyYkleiIiMmCiJyJyOMckehEZKCK7RCRJRMYUw/oaisjvIrJdRLaJyD/06S+JSLqIbNT/Bhve86we3y4RGRCu2EUkRUS26OtP0KdVE5GFIrJH/19Vny4iMkFf92YR6WRYzj36/HtE5J4ixtTKsE02isgpEXk8EttLRCaJyFER2WqYFrLtIyKd9e2fpL9XihDX2yKyU1/3jyJSRZ8eJyLnDNvt40Drt/qMhYwrZN+biDQRkTX69OkiUqoIcU03xJQiIhsjsL2sckPk9jGl1EX/ByAGwF4ATQGUArAJQJswr7MugE7644oAdgNoA+AlAE+ZzN9Gj6s0gCZ6vDHhiB1ACoAaXtPeAjBGfzwGwJv648EA5gEQAN0ArNGnVwOQrP+vqj+uGsLv6zCAxpHYXgB6AegEYGs4tg+Atfq8or93UBHi6g8gVn/8piGuOON8XssxXb/VZyxkXCH73gB8B2CY/vhjAA8XNi6v18cDeDEC28sqN0RsH3NKib4rgCSlVLJSKhvANABDw7lCpdQhpdR6/fFpADsA1PfzlqEApimlLiil9gFI0uMurtiHAvhSf/wlgBsN079SmtUAqohIXQADACxUSmUppY4DWAhgYIhiuRbAXqWUvyugw7a9lFLLAGSZrK/I20d/rZJSarXSjsivDMsKOi6l1AKlVK7+dDUAv79JGGD9Vp8x6Lj8COp700ui1wCYGcq49OXeDmCqv2WEaXtZ5YaI7WNOSfT1ARwwPE+D/6QbUiISB6AjgDX6pNF6FWySobpnFWM4YlcAFohIooiM1KfVVkod0h8fBlA7AnG5DIPnARjp7QWEbvvU1x+HOj4AuA9a6c2liYhsEJGlInK1IV6r9Vt9xsIKxfdWHcAJw8ksVNvragBHlFJ7DNOKfXt55YaI7WNOSfQRIyIVAHwP4HGl1CkAHwFoBqADgEPQqo/FradSqhOAQQAeFZFexhf1UkBExtXq7a83AJihT4qG7eUhktvHiog8DyAXwLf6pEMAGimlOgJ4EsAUEalkd3kh+IxR9715GQ7PwkSxby+T3FCk5RWFUxJ9OoCGhucN9GlhJSIloX2R3yqlfgAApdQRpVSeUiofwKfQqqz+Ygx57EqpdP3/UQA/6jEc0at8rurq0eKOSzcIwHql1BE9xohvL12otk86PJtXihyfiIwAcD2Av+kJAnrTSKb+OBFa+3fLAOu3+oxBC+H3lgmtqSLWJN5C0Zd1M4DphniLdXuZ5QY/ywv/PmancyHa/wDEQuuoaIKCjp62YV6nQGsbe99rel3D4yegtVcCQFt4dlIlQ+ugCmnsAMoDqGh4/Ae0tvW34dkR9Jb+eAg8O4LWqoKOoH3QOoGq6o+rhWC7TQNwb6S3F7w650K5feDbUTa4CHENBLAdQE2v+WoCiNEfN4V2oPtdv9VnLGRcIfveoNXujJ2xjxQ2LsM2Wxqp7QXr3BCxfSxsibC4/6D1XO+GdqZ+vhjW1xNa1WszgI3632AAXwPYok+f7XVAPK/HtwuGXvJQxq7vxJv0v22u5UFrC10MYA+ARYYdRgBM1Ne9BUC8YVn3QetMS4IhORchtvLQSnCVDdOKfXtBq9IfApADrX3z/lBuHwDxALbq7/kA+hXohYwrCVo7rWsf+1if9xb9+90IYD2AvwRav9VnLGRcIfve9H12rf5ZZwAoXdi49OlfABjlNW9xbi+r3BCxfYy3QCAicjintNETEZEFJnoiIodjoicicjgmeiIih2OiJyJyOCZ6IiKHY6InInK4/weSmu2tJtBRHwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# 展示模型损失优化的过程\n", "plt.plot(torch.tensor(lossi))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([32561, 2]), torch.Size([32561, 2]), torch.Size([32561]))" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 关闭梯度追踪\n", "with torch.no_grad():\n", " logits = model(x) \n", " probs = F.softmax(logits, dim=1)\n", " pred = torch.where(probs[:, 1] > 0.5, 1, 0)\n", "logits.shape, probs.shape, pred.shape" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(0.2928), tensor(0.5902), tensor(0.3914))" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from torcheval.metrics.functional.classification import binary_recall\n", "from torcheval.metrics.functional import binary_precision, binary_f1_score\n", "binary_recall(pred, y), binary_precision(pred, y), binary_f1_score(pred, y)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# 展示如何利用排序数据得到对偏好的估计\n", "# 此处只做模型结构展示,并不训练和使用模型\n", "class PreferenceModel:\n", " \n", " def __init__(self, pref):\n", " self.pref = pref\n", " \n", " def __call__(self, x0, x1):\n", " self.out = torch.concat((self.pref(x0), self.pref(x1)), dim=1)\n", " return self.out\n", " \n", " def parameters(self):\n", " return self.pref.parameters()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# 预测偏好的模型\n", "preference = Linear(5, 1)\n", "# 将两个数据的偏好组合在一起,以便和排序数据结合在一起\n", "p_model = PreferenceModel(preference)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[0.0679, 1.1044]]), tensor([[0.2618, 0.7382]]))" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 随机选取x0和x1\n", "x0 = x[[0]]\n", "x1 = x[[1]]\n", "p_logits = p_model(x0, x1)\n", "# 得到有偏好推导出来的排序概率,该数据可以与观测到的实际排序相结合,定义模型损失\n", "p_probs = F.softmax(p_logits, dim=1)\n", "p_logits, p_probs" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }