{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from sklearn.datasets import make_moons\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Seed PyTorch so parameter initialisation and torch.randn draws are reproducible\n",
    "torch.manual_seed(1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the linear layer and the Sigmoid activation\n",
    "\n",
    "class Linear:\n",
    "    '''Fully connected layer: out = x @ weight (+ bias).'''\n",
    "\n",
    "    def __init__(self, in_features, out_features, bias=True):\n",
    "        '''\n",
    "        Initialise the model parameters.\n",
    "        Note: the initialisation is deliberately NOT optimised here;\n",
    "        weight and bias are drawn from a standard normal distribution.\n",
    "        '''\n",
    "        self.weight = torch.randn((in_features, out_features), requires_grad=True)  # (in_features, out_features)\n",
    "        self.bias = torch.randn(out_features, requires_grad=True) if bias else None  # (out_features,)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        # x: (B, in_features)\n",
    "        # self.weight: (in_features, out_features)\n",
    "        self.out = x @ self.weight  # (B, out_features)\n",
    "        if self.bias is not None:\n",
    "            self.out += self.bias  # bias broadcasts over the batch dimension\n",
    "        return self.out\n",
    "\n",
    "    def parameters(self):\n",
    "        '''\n",
    "        Return the layer's parameters, used for the iterative parameter update.\n",
    "        Since PyTorch's unit of computation is the tensor, collecting the\n",
    "        parameter tensors into a list is all that is needed.\n",
    "        '''\n",
    "        if self.bias is not None:\n",
    "            return [self.weight, self.bias]\n",
    "        return [self.weight]\n",
    "\n",
    "\n",
    "class Sigmoid:\n",
    "    '''Element-wise logistic sigmoid activation.'''\n",
    "\n",
    "    def __call__(self, x):\n",
    "        self.out = torch.sigmoid(x)\n",
    "        return self.out\n",
    "\n",
    "    def parameters(self):\n",
    "        '''\n",
    "        The Sigmoid function has no model parameters.\n",
    "        '''\n",
    "        return []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Sequential:\n",
    "    '''Apply a list of layers one after another.'''\n",
    "\n",
    "    def __init__(self, layers):\n",
    "        # layers: the model components, e.g. Linear layers and Sigmoid activations\n",
    "        self.layers = layers\n",
    "\n",
    "    def __call__(self, x):\n",
    "        for layer in self.layers:\n",
    "            x = layer(x)\n",
    "        self.out = x\n",
    "        return self.out\n",
    "\n",
    "    def parameters(self):\n",
    "        # Flatten every layer's parameter list into a single list\n",
    "        return [p for layer in self.layers for p in layer.parameters()]\n",
    "\n",
    "    def predict_proba(self, x):\n",
    "        # Compute class-probability predictions (softmax over the last dim) as a NumPy array.\n",
    "        # Accepts a NumPy array (converted to a float32 tensor) or a tensor.\n",
    "        if isinstance(x, np.ndarray):\n",
    "            x = torch.tensor(x).float()\n",
    "        # Inference only: no_grad avoids building an autograd graph that would be discarded\n",
    "        with torch.no_grad():\n",
    "            logits = self(x)\n",
    "            self.prob = F.softmax(logits, dim=-1).numpy()\n",
    "        return self.prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5756, 0.0729],\n",
       " [-0.5902, 0.0307],\n",
       " [-0.5812, 0.0562]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# x: (B, 2)\n",
    "# mlp: [4, 4, 2]\n",
    "model = Sequential([\n",
    "    Linear(2, 4), Sigmoid(),  # (B, 4)\n",
    "    Linear(4, 4), Sigmoid(),  # (B, 4)\n",
    "    Linear(4, 2)              # (B, 2) logits\n",
    "])\n",
    "x = torch.randn(3, 2)\n",
    "model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.34332383, 0.6566762 ],\n",
       " [0.34957516, 0.6504248 ],\n",
       " [0.34582347, 0.65417653]], dtype=float32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_data(data):\n",
    "    '''\n",
    "    Visualise the data: points with label > 0 as circles, label == 0 as black triangles.\n",
    "    data: (features, labels) pair, as returned by make_moons.\n",
    "    Returns the matplotlib Axes so further layers can be drawn on it.\n",
    "    '''\n",
    "    fig, ax = plt.subplots(figsize=(5, 5))\n",
    "    x, y = data\n",
    "    label1 = x[y > 0]\n",
    "    ax.scatter(label1[:, 0], label1[:, 1], marker='o')\n",
    "    label0 = x[y == 0]\n",
    "    ax.scatter(label0[:, 0], label0[:, 1], marker='^', color='k')\n",
    "    return ax\n",
    "\n",
    "def draw_model(ax, model, resolution=100):\n",
    "    '''\n",
    "    Visualise the model's decision boundary on an existing Axes.\n",
    "    A resolution x resolution grid spanning the current axis limits is scored with\n",
    "    model.predict_proba; the region where P(class 1) is between 0 and 0.5 is shaded grey.\n",
    "    '''\n",
    "    x1 = np.linspace(*ax.get_xlim(), resolution)\n",
    "    x2 = np.linspace(*ax.get_ylim(), resolution)\n",
    "    x1, x2 = np.meshgrid(x1, x2)\n",
    "    y = model.predict_proba(np.c_[x1.ravel(), x2.ravel()])[:, 1]\n",
    "    y = y.reshape(x1.shape)\n",
    "    ax.contourf(x1, x2, y, levels=[0, 0.5], colors=['gray'], alpha=0.4)\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "
" ] },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# make_moons draws from NumPy's global RNG, which this notebook never seeds;\n",
    "# fixing random_state makes the dataset (and therefore the whole run) reproducible.\n",
    "data = make_moons(200, noise=0.05, random_state=1024)\n",
    "draw_data(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 20\n",
    "max_steps = 40000\n",
    "learning_rate = 0.1\n",
    "x, y = torch.tensor(data[0]).float(), torch.tensor(data[1])\n",
    "model = Sequential([\n",
    "    Linear(2, 4), Sigmoid(),  # (B, 4)\n",
    "    Linear(4, 4), Sigmoid(),  # (B, 4)\n",
    "    Linear(4, 2)              # (B, 2) logits\n",
    "])\n",
    "lossi = []  # training loss recorded at every step\n",
    "\n",
    "for t in range(max_steps):\n",
    "    # Walk through the dataset in consecutive mini-batches, wrapping around at the end\n",
    "    ix = (t * batch_size) % len(x)\n",
    "    xx = x[ix: ix + batch_size]\n",
    "    yy = y[ix: ix + batch_size]\n",
    "    logits = model(xx)\n",
    "    loss = F.cross_entropy(logits, yy)\n",
    "    loss.backward()\n",
    "    # Plain SGD step; the in-place updates must not be recorded by autograd\n",
    "    with torch.no_grad():\n",
    "        for p in model.parameters():\n",
    "            p -= learning_rate * p.grad\n",
    "            p.grad = None  # reset so gradients do not accumulate across steps\n",
    "    lossi.append(loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ax = draw_data(data)\n", "draw_model(ax, model)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWyElEQVR4nO3de5gdBX3G8e+7u9kkhHAJWQVJcEGDGgQVV5AHL6ioIfAQW7GGttZ6S63SeqHVUAQp2hbwEfGCl6gU8QICapvHxAIVLHgJZCMQkmBgSYIkIFkIkATIZcmvf5wJObuc3T17ds6ZM7Pv53n22bmdmTdzNu+enTlnRhGBmZnlX0vWAczMLB0udDOzgnChm5kVhAvdzKwgXOhmZgXRltWGp06dGp2dnVlt3swsl5YtW/ZIRHRUmpdZoXd2dtLd3Z3V5s3McknS/YPN8yEXM7OCcKGbmRWEC93MrCBc6GZmBeFCNzMrCBe6mVlBuNDNzAoid4W++k9buPj61TyydXvWUczMmkruCr1n41a+cmMPm57ckXUUM7OmkrtCNzOzylzoZmYF4UI3MysIF7qZWUEMW+iSLpO0UdKKQeb/laTlku6S9FtJr0g/ppmZDaeaV+iXA7OGmL8WeGNEHAl8DliQQi4zMxuhYa+HHhE3S+ocYv5vy0aXANNSyGVmZiOU9jH0DwC/GGympHmSuiV19/b2prxpM7OxLbVCl/QmSoX+6cGWiYgFEdEVEV0dHRXvoGRmZjVK5RZ0ko4CvgOcFBGPprFOMzMbmVG/Qpd0CPBT4D0Rcc/oI5mZWS2GfYUu6UrgBGCqpPXAZ4FxABHxTeBc4ADg65IA+iKiq16Bd4uo9xbMzPKlmne5nD7M/A8CH0wt0TBKvzPMzGwgf1LUzKwgXOhmZgXhQjczKwgXuplZQbjQzcwKwoVuZlYQLnQzs4JwoZuZFYQL3cysIFzoZmYF4UI3MysIF7qZWUHkttADX27RzKxc7grdF1s0M6ssd4VuZmaVudDNzArChW5mVhAudDOzgnChm5kVhAvdzKwgXOhmZgXhQjczKwgXuplZQQxb6JIuk7RR0opB5kvSVyT1SFou6ej0Y5qZ2XCqeYV+OTBriPknATOSr3nAN0Yfy8zMRmrYQo+Im4FNQywyB7giSpYA+0k6KK2Ag+eq9xbMzPIljWPoBwMPlI2vT6Y9h6R5kroldff29ta0MfnqXGZmFTX0pGhELIiIrojo6ujoaOSmzcwKL41C3wBMLxuflkwzM7MGSqPQFwJ/k7zb5bXAExHxUArrNTOzEWgbbgFJVwInAFMlrQc+C4wDiIhvAouB2UAP8BTwvnqFNTOzwQ1b6BFx+jDzA/hoaonMzKwm/qSomVlBuNDNzArChW5mVhAudDOzgnChm5kVhAvdzKwgXOhmZgWR20L31RbNzPrLYaH7cotmZpXksNDNzKwSF7qZWUG40M3MCsKFbmZWEC50M7OCcKGbmRWEC93MrCBc6GZmBeFCNzMrCBe6mVlBuNDNzAoit4Ue+OpcZmblclfo8rW5zMwqqqrQJc2StFpSj6T5FeYfIukmSbdLWi5pdvpRzcxsKMMWuqRW4FLgJGAmcLqkmQMW+wxwdUS8CpgLfD3toGZmN
rRqXqEfA/RExJqI2AFcBcwZsEwA+yTD+wIPphfRzMyq0VbFMgcDD5SNrweOHbDMecD1kv4BmAScmEo6MzOrWlonRU8HLo+IacBs4PuSnrNuSfMkdUvq7u3tTWnTZmYG1RX6BmB62fi0ZFq5DwBXA0TE74AJwNSBK4qIBRHRFRFdHR0dtSU2M7OKqin0pcAMSYdKaqd00nPhgGX+CLwFQNLLKBW6X4KbmTXQsIUeEX3AGcB1wN2U3s2yUtL5kk5NFjsT+JCkO4Ergb+NiLp88ueJp3cCsHHz9nqs3swst6o5KUpELAYWD5h2btnwKuD4dKNVdu2y9QB86+b7eNNLn9eITZqZ5ULuPilqZmaVudDNzAoit4VenyP0Zmb5lbtC97W5zMwqy12hm5lZZS50M7OCcKGbmRWEC93MrCByW+h+k4uZWX+5LXQzM+svt4Xuty+amfWX20K/de2mrCOYmTWV3BX67qstmplZf7krdDMzqyx3hd7a4qPnZmaV5K7Q21r3RN61y29eNDPbLXeFXv4C/aqlD2QXxMysyeSu0O/buPXZ4Z/8fn2GSczMmkvuCr2v7DDLsvsfyzCJmVlzyV2ht8onRc3MKslfobe60M3MKslfofsVuplZRbkr9Ba/D93MrKKqCl3SLEmrJfVImj/IMn8haZWklZJ+lG7MPdznZmaVtQ23gKRW4FLgrcB6YKmkhRGxqmyZGcBZwPER8Zik59Ur8IuftzcPb97+7PjmbTvZZ8K4em3OzCw3qnmFfgzQExFrImIHcBUwZ8AyHwIujYjHACJiY7ox9zhkyqR+479a3VuvTZmZ5Uo1hX4wUP6RzPXJtHKHA4dL+o2kJZJmVVqRpHmSuiV19/bWWsT9P+5/4S/+UON6zMyKJa2Tom3ADOAE4HTg25L2G7hQRCyIiK6I6Oro6EhlwxsefzqV9ZiZ5V01hb4BmF42Pi2ZVm49sDAidkbEWuAeSgWfuvD1uMzMKqqm0JcCMyQdKqkdmAssHLDMf1F6dY6kqZQOwaxJL+Ye+0z0CVAzs0qGLfSI6APOAK4D7gaujoiVks6XdGqy2HXAo5JWATcB/xwRj9Yj8CdOPLweqzUzy71h37YIEBGLgcUDpp1bNhzAJ5OvuprY3lrvTZiZ5VLuPilaybadz2Qdwcwsc4Uo9N4t24dfyMys4ApR6P/ys7uyjmBmlrlCFPot9z6SdQQzs8wVotDNzMyFbmZWGLks9NlHHph1BDOzppPLQn/FtP2yjmBm1nRyWejHv3hq1hHMzJpOLgu90m1FH93q96Kb2diWz0LnuY3+9V/dl0ESM7PmkctCb297buzv/nptBknMzJpHLgv9hQfslXUEM7Omk8tCNzOz58plobe1VDgrCoRvZ2RmY1guC12V3uYCfPXGngYnMTNrHrks9MFcfMM9WUcwM8tMVXcsypPO+Yu46LSj+PW9j7DwzgczzdLWIg5//mRecuBkXnbQZF5y4D687MDJdEweP+hfGWZmtSpcoQN86trlWUcAoG9XsOqhzax6aDM/u3306/viu17Bnx99sH8ZmFlFhSz0ojrzmjs585o7nx1fft7b2GfCuAwTmVkzKdQx9LHmqPOup3P+oqxjmFmTcKEXQOf8RX7Lppm50Ivi0LMWZx3BzDJWVaFLmiVptaQeSfOHWO6dkkJSV3oRrVqnfPWWrCOYWYaGLXRJrcClwEnATOB0STMrLDcZ+Bhwa9ohrTorNmzOOoKZZaiaV+jHAD0RsSYidgBXAXMqLPc54EJgW4r5bIR8ktRs7Kqm0A8GHigbX59Me5ako4HpETFkm0iaJ6lbUndvb++Iw5qZ2eBGfVJUUgtwMXDmcMtGxIKI6IqIro6OjtFu2gax4fGns45gZhmoptA3ANPLxqcl03abDLwc+JWkdcBrgYU+MZqd4y+4MesIZpaBagp9KTBD0qGS2oG5wMLdMyPiiYiYGhGdEdEJLAFOjYjuuiQ2M7OKhi30iOgDzgCuA+4Gr
o6IlZLOl3RqvQNabXbt8geNzMaaqq7lEhGLgcUDpp07yLInjD6Wjdbcby/h6r87LusYZtZA/qRoQd22dlPWEcyswVzoZmYFkdtC/8tjD8k6QtN7xsfRzcaU3Bb65PG+lPtwfvL79VlHMLMGym2h45v2DKtZ7txkZo2R20I/cJ8JWUcwM2squS3013ROyTqCmVlTyW2hv2C/iVlHyIXHntyRdQQza5DcFvqUSe1ZR8iFV33uhqwjmFmD5LbQzcysPxe6mVlBuNDHgL5ndmUdwcwawIU+BnzhutVZRzCzBnChjwHfunlN1hHMrAFc6GZmBeFCNzMrCBf6GPHw5m1ZRzCzOnOhjxHH/vsvs45gZnXmQjczKwgXuplZQeS60CeOa806gplZ08h1oR//4qlZR8iVS2/qyTqCmdVRVYUuaZak1ZJ6JM2vMP+TklZJWi7pl5JemH7U5/r4iTMasZnC8CdGzYpt2EKX1ApcCpwEzAROlzRzwGK3A10RcRRwLXBR2kErmT5lr0ZsxswsF6p5hX4M0BMRayJiB3AVMKd8gYi4KSKeSkaXANPSjVnZvhPHNWIzZma5UE2hHww8UDa+Ppk2mA8Av6g0Q9I8Sd2Sunt7e6tPaalZfNdDWUcwszpJ9aSopL8GuoAvVJofEQsioisiujo6OtLctFXpIz/8fdYRzKxO2qpYZgMwvWx8WjKtH0knAmcDb4yI7enEMzOzalXzCn0pMEPSoZLagbnAwvIFJL0K+BZwakRsTD+mpSkiso5gZnUwbKFHRB9wBnAdcDdwdUSslHS+pFOTxb4A7A1cI+kOSQsHWV3q9h5fzR8ZVu7iG+7JOoKZ1YGyerXW1dUV3d3do17PPQ9v4W1fujmFRGPLugtOzjqCmdVA0rKI6Ko0L9efFAU4/PmTs45gZtYUcl/oVpt7H96SdQQzS1khCv2ASe1ZR8idt/owlVnhFKLQuz9zYtYRzMwyV4hCl5R1hFz6w582Zx3BzFJUiEIHv2ujFrMuuSXrCGaWosIUOsCif3xd1hHMzDJTqEI/4gX7cuOZb8w6Rq4cf8GNWUcws5QUqtABDuvYm3UXnMwpRx2UdZRc2PD401lHMLOU5P6TotV47MkdvOXi/2PTkzsasr28eXfXdC487aisY5hZFYb6pOiYKPR6eGZXsPnpnWzetpMnnt7Jlm19yddOntrxDFu397F1ex9Pbe9jy/Y+nkzGt2wrfX/8qdLjntnVHBfK8klls3wYqtB9ZasatbaI/Se1s38TfKgpIvj8orv57q/X1ryOzvmLXOpmOVe4Y+hjkSTOOWUm6y44mbX/Mbvm9Ty5vS/FVGbWaC70gpHEugtO5or3HzPixx7x2evqkMjMGsWFXlBvOLyDpWeP/JIInfMX1SGNmTWCC73AOiaP5/Zz3jrix7nUzfLJhV5w+09q55ZPvWnEj3Opm+WPC30MmD5lL6798HEjflzn/EVs2/lMHRKZWT240MeIrs4pfOs9rx7x4156zv/41bpZTrjQx5C3H3Egl7/vNTU9tnP+Ihe7WZNzoY8xJ7zkefxm/ptrfvzuYv/i9atTTGVmafBH/8eoiODQsxantr4Pv/FFfHrWS3yzEbM687VcbFDL7t/EO7/xu7pu40cfPJbjXnSAy94sBaMudEmzgC8DrcB3IuKCAfPHA1cArwYeBd4dEeuGWqcLvbk8sOkpXn/RTVnHeNbrZ0xl9pEH8foZU3nBvhNpafEvAzMYZaFLagXuAd4KrAeWAqdHxKqyZT4CHBURH5Y0F/iziHj3UOt1oTev7/9uHef898qsY4xZE8a1MHnCOPYe38ak8a3s1d7GpPbS9/HjWpjU3sb4thYmjGtlYnsr41rF+LZWxre1MK61NL21BdqT8dYW0Z58b2spfS//ahHJ99K4gJZkfPcwQIugRbuHxe4/uJ79Ttk0ds/b84u4/FfySP5YK9pfdk/t6GOv9tqvizjaqy0eA/RExJpkZVcBc4BVZcvMAc5Lh
q8FviZJkdXxHBuV9xzXyXuO63x2fEffLs7/+Up+sOSP2YUaQ7bt3MW2ndvp3bI96yhWR/W4umk1hX4w8EDZ+Hrg2MGWiYg+SU8ABwCPlC8kaR4wD+CQQw6pMbI1WntbC59/x5F8/h1HVv2YiGDbzl30bNzKigefYOWDT7Dqwc3cu3ErW7b5qo42tp1zysy6rLeh10OPiAXAAigdcmnktq2xJDGxvZUjp+3LkdP2zTqO2ZhQzfvQNwDTy8anJdMqLiOpDdiX0slRMzNrkGoKfSkwQ9KhktqBucDCAcssBN6bDJ8G3Ojj52ZmjTXsIZfkmPgZwHWU3rZ4WUSslHQ+0B0RC4HvAt+X1ANsolT6ZmbWQFUdQ4+IxcDiAdPOLRveBrwr3WhmZjYSvpaLmVlBuNDNzArChW5mVhAudDOzgsjsaouSeoH7a3z4VAZ8CrVJNGsuaN5szjUyzjUyRcz1wojoqDQjs0IfDUndg12cJkvNmguaN5tzjYxzjcxYy+VDLmZmBeFCNzMriLwW+oKsAwyiWXNB82ZzrpFxrpEZU7lyeQzdzMyeK6+v0M3MbAAXuplZQeSu0CXNkrRaUo+k+Q3a5jpJd0m6Q1J3Mm2KpBsk3Zt83z+ZLklfSfItl3R02Xremyx/r6T3Dra9IXJcJmmjpBVl01LLIenVyb+zJ3lsVTdzHCTXeZI2JPvsDkmzy+adlWxjtaS3l02v+Nwml26+NZn+4+QyztXkmi7pJkmrJK2U9LFm2GdD5Mp0n0maIOk2SXcmuf51qHVJGp+M9yTzO2vNW2OuyyWtLdtfr0ymN+xnP3lsq6TbJf088/0VEbn5onT53vuAw4B24E5gZgO2uw6YOmDaRcD8ZHg+cGEyPBv4BaV74r4WuDWZPgVYk3zfPxnef4Q53gAcDayoRw7gtmRZJY89aRS5zgP+qcKyM5PnbTxwaPJ8tg713AJXA3OT4W8Cf19lroOAo5PhyZRudj4z6302RK5M91nyb9g7GR4H3Jr82yquC/gI8M1keC7w41rz1pjrcuC0Css37Gc/eewngR8BPx9q3zdif+XtFfqzN6yOiB3A7htWZ2EO8L1k+HvAO8qmXxElS4D9JB0EvB24ISI2RcRjwA3ArJFsMCJupnS9+dRzJPP2iYglUfopu6JsXbXkGswc4KqI2B4Ra4EeSs9rxec2eaX0Zko3Hx/4bxwu10MR8ftkeAtwN6X732a6z4bINZiG7LPk3701GR2XfMUQ6yrfj9cCb0m2PaK8o8g1mIb97EuaBpwMfCcZH2rf131/5a3QK92weqj/CGkJ4HpJy1S60TXA8yPioWT4T8Dzh8lYr+xp5Tg4GU4z3xnJn7yXKTmsUUOuA4DHI6JvwPQRSf68fRWlV3dNs88G5IKM91ly+OAOYCOlwrtviHX1uzk8sPvm8Kn/HxiYKyJ2769/S/bXlySNH5iryu2P5nm8BPgUsCsZH2rf131/5a3Qs/K6iDgaOAn4qKQ3lM9Mfqtn/v7PZsmR+AbwIuCVwEPAF7MKImlv4CfAxyNic/m8LPdZhVyZ77OIeCYiXknp3sHHAC9tdIZKBuaS9HLgLEr5XkPpMMqnG5lJ0inAxohY1sjtDiVvhV7NDatTFxEbku8bgZ9R+kF/OPlTjeT7xmEy1it7Wjk2JMOp5IuIh5P/hLuAb1PaZ7XkepTSn8xtA6ZXRdI4SqX5w4j4aTI5831WKVez7LMky+PATcBxQ6xrsJvD1+3/QFmuWcmhq4iI7cB/Uvv+qvV5PB44VdI6SodD3gx8mSz311AH2Jvti9It89ZQOnGw+yTBEXXe5iRgctnwbykd+/4C/U+sXZQMn0z/EzK3xZ4TMmspnYzZPxmeUkOeTvqffEwtB889MTR7FLkOKhv+BKVjhABH0P8E0BpKJ38GfW6Ba+h/kukjVWYSpeOhlwyYnuk+GyJXpvsM6AD2S4YnArcApwy2LuCj9D/Jd3WteWvMdVDZ/rwEuCCLn/3k8Sew5
6RoZvsrk2IezRelM9j3UDq2d3YDtndYsiPvBFbu3ialY1+/BO4F/rfsB0PApUm+u4CusnW9n9IJjx7gfTVkuZLSn+I7KR1P+0CaOYAuYEXymK+RfJK4xlzfT7a7HFhI/7I6O9nGasreTTDYc5s8B7clea8BxleZ63WUDqcsB+5IvmZnvc+GyJXpPgOOAm5Ptr8COHeodQETkvGeZP5hteatMdeNyf5aAfyAPe+EadjPftnjT2BPoWe2v/zRfzOzgsjbMXQzMxuEC93MrCBc6GZmBeFCNzMrCBe6mVlBuNDNzArChW5mVhD/D+7coXzaigN5AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Raw training loss: one point per training step (mini-batch)\n", "plt.plot(lossi)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([1.5000, 3.5000, 5.5000, 7.5000, 9.5000])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Toy check of the smoothing used in the next cell: view(-1, 2) groups consecutive\n", "# values into pairs and mean(dim=-1) averages each pair -> [1.5, 3.5, ...]\n", "torch.linspace(1, 10, 10).view(-1, 2).mean(dim=-1)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Smooth the loss curve by averaging over consecutive windows of `window` steps.\n",
    "# Trailing steps that do not fill a whole window are dropped, so .view() cannot fail\n",
    "# when len(lossi) is not a multiple of the window size.\n",
    "window = 100\n",
    "losses = torch.tensor(lossi)\n",
    "n_full = len(losses) // window * window\n",
    "plt.plot(losses[:n_full].view(-1, window).mean(dim=-1))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}