{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "\n", "torch.manual_seed(1024)" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
 "# Linear layer used by the models below (the sigmoid activation is defined in a later cell)\n",
 "\n",
 "class Linear:\n",
 "    # Affine map: out = x @ weight (+ bias)\n",
 "    # input:  (B, in_features)\n",
 "    # output: (B, out_features)\n",
 "\n",
 "    def __init__(self, in_features, out_features, bias=True):\n",
 "        # Parameter initialisation is deliberately left unoptimised: plain standard-normal draws\n",
 "        self.weight = torch.randn(in_features, out_features, requires_grad=True)  # (in_features, out_features)\n",
 "        # The bias is drawn after the weight, and only when it is requested\n",
 "        self.bias = torch.randn(out_features, requires_grad=True) if bias else None  # (out_features,)\n",
 "\n",
 "    def __call__(self, x):\n",
 "        # x: (B, in_features), self.weight: (in_features, out_features)\n",
 "        out = x @ self.weight  # (B, out_features)\n",
 "        if self.bias is not None:\n",
 "            out = out + self.bias  # bias is broadcast across the batch dimension\n",
 "        self.out = out  # keep the most recent output on the instance\n",
 "        return out\n",
 "\n",
 "    def parameters(self):\n",
 "        # Trainable tensors: the weight, followed by the bias when there is one\n",
 "        params = [self.weight]\n",
 "        if self.bias is not None:\n",
 "            params.append(self.bias)\n",
 "        return params" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 4])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l = Linear(3, 4)\n", "x = torch.randn(5, 3)\n", "l(x).shape" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([[ 0.6459, -0.0353, 0.5852, 0.5732],\n", " [-1.0110, 0.2098, 0.4153, 0.0819],\n", " [-0.3151, -0.5068, 0.3941, 0.3839]], requires_grad=True),\n", " tensor([ 0.5583, -1.1253, 1.5603, 1.9050], requires_grad=True)]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "l.parameters()" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
 "class Sigmoid:\n",
 "    # Element-wise logistic activation; it owns no trainable tensors\n",
 "\n",
 "    def __call__(self, x):\n",
 "        activated = torch.sigmoid(x)\n",
 "        self.out = activated  # keep the most recent output on the instance\n",
 "        return activated\n",
 "\n",
 "    def parameters(self):\n",
 "        # Nothing to train here\n",
 "        return []" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 2])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "s = Sigmoid()\n", "x = torch.randn(3, 2)\n", "s(x).shape" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
 "class Perceptron:\n",
 "    # One linear unit (a single output feature) followed by a sigmoid\n",
 "\n",
 "    def __init__(self, in_features):\n",
 "        self.ln = Linear(in_features, 1)\n",
 "        self.f = Sigmoid()\n",
 "\n",
 "    def __call__(self, x):\n",
 "        # x: (B, in_features)\n",
 "        pre_activation = self.ln(x)  # (B, 1)\n",
 "        self.out = self.f(pre_activation)  # (B, 1)\n",
 "        return self.out\n",
 "\n",
 "    def parameters(self):\n",
 "        # Linear parameters first, then the (empty) activation parameters\n",
 "        return [*self.ln.parameters(), *self.f.parameters()]" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 1])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "p = Perceptron(3)\n", "x = torch.randn(5, 3)\n", "p(x).shape" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
 "class LogitRegression:\n",
 "    # input:  (B, in_features)\n",
 "    # output: (B, 2) -- column 0 comes from `pos`, column 1 from `neg`\n",
 "\n",
 "    def __init__(self, in_features):\n",
 "        # Two independent single-output linear layers; `pos` is created before `neg`\n",
 "        self.pos = Linear(in_features, 1)\n",
 "        self.neg = Linear(in_features, 1)\n",
 "\n",
 "    def __call__(self, x):\n",
 "        # x: (B, in_features)\n",
 "        pos_score = self.pos(x)  # (B, 1)\n",
 "        neg_score = self.neg(x)  # (B, 1)\n",
 "        self.out = torch.cat([pos_score, neg_score], dim=-1)  # (B, 2)\n",
 "        return self.out\n",
 "\n",
 "    def parameters(self):\n",
 "        # Parameters of `pos` followed by those of `neg`\n",
 "        return [*self.pos.parameters(), *self.neg.parameters()]" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([5, 2])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lr = LogitRegression(3)\n", "x = 
torch.randn(5, 3)\n", "lr(x).shape" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.4861, 0.8908],\n", " [-1.2695, -0.6651],\n", " [ 1.6948, 1.6572],\n", " [-0.6641, -0.0216],\n", " [-0.9517, 0.0869]], grad_fn=)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Forward pass: one row of two raw scores (logits) per sample in x\n", "logits = lr(x)\n", "logits" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[-0.4861],\n", " [-1.2695],\n", " [ 1.6948],\n", " [-0.6641],\n", " [-0.9517]], grad_fn=)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The `pos` branch on its own reproduces column 0 of `logits`\n", "lr.pos(x)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "# Softmax over the last dimension turns each row of logits into probabilities that sum to 1\n", "probs = F.softmax(logits, dim=-1)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.2015, 0.7985],\n", " [0.3533, 0.6467],\n", " [0.5094, 0.4906],\n", " [0.3447, 0.6553],\n", " [0.2614, 0.7386]], grad_fn=)\n", "tensor([1, 1, 0, 1, 1])\n" ] } ], "source": [ "# Predicted class = index of the larger probability in each row\n", "pred = torch.argmax(probs, dim=-1)\n", "print(probs)\n", "print(pred)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.4122, grad_fn=)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Cross-entropy of the logits, using the model's own predictions as the target labels\n", "loss = F.cross_entropy(logits, pred)\n", "loss" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[1.6020, 0.2250],\n", " [1.0403, 0.4359],\n", " [0.6745, 0.7121],\n", " [1.0651, 0.4226],\n", " [1.3416, 0.3030]], grad_fn=)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Element-wise negative log-probabilities (the per-class terms behind the cross-entropy above)\n", "-probs.log()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { 
"data": { "text/plain": [ "tensor(0.4122, grad_fn=)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-probs.log()[range(5), pred].mean()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import make_blobs\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "data = make_blobs(200, centers=[[-2, -2], [2, 2]])\n", "x, y = data" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAAsTAAALEwEAmpwYAAAg3UlEQVR4nO3dXYwc13Un8P+ZZilsWg5HhiYI1PoggQ0kWKuIAw2cBfiQSHYkJ5KlgZVEySIBkjwICySAJTB0aBtYUVgHZsAk8gIJEAjehwVMbEhFXkaKHDA2qCywAmR4qOHES0sMnDiS3XYQBlLbsact9sycfZipYU/NvVW3qm593O7/DxBsNnuqq2uIU7fOPfdcUVUQEVG4Zpo+ASIiKoeBnIgocAzkRESBYyAnIgocAzkRUeD2NPGhN954ox44cKCJjyYiCtaFCxf+TVXnkq83EsgPHDiApaWlJj6aiChYIvKm6XWmVoiIAsdATkQUOAZyIqLAMZATEQWOgZyIKHCNVK0QTaKzy32cPHcZ3xkMcdNsF0cfuB2L872mT4umAAM5kQdnl/v4xBe+huFoHQDQHwzxiS98DQAYzKlyDOREHpw8d3k7iMeGo3WcPHeZgdwjPvWYMZATefCdwTDX65Qfn3rsONlJ5MFNs91cr1N+aU89046BnMiDow/cjm7U2fFaN+rg6AO3N3RGk4dPPXYM5EQeLM738JmP3oXebBcCoDfbxWc+etfUP/L7xKceO+bIiTxZnO8xcFfo6AO378iRA3zqiTGQE02ZUCs/4nMM8dyr5i2Qi0gHwBKAvqo+5Ou4RORP6JUffOox85kj/xiA1z0ej4g8Y+XHZPISyEXkZgAPAvicj+MRUTVY+TGZfI3IPwvg4wA2bG8QkcdFZElElq5cueLpY4koD1Z+TKbSgVxEHgLwr6p6Ie19qvqsqi6o6sLc3K4t54gImznswyfO4+Cxl3D4xHmcXe57PT7r3SeTj8nOwwAeFpFfBLAXwI+LyOdV9dc9HJtoatQxEcnKj8kkqurvYCI/B+D3sqpWFhYWlJsvE+10+MR59A256t5WsGXwJRG5oKoLyddZR07UErYJx3hkXuVIvWhteag16ZPG6xJ9Vf071pATFWObcOyIVFoyGKd0+oMhFNduFFn5+aI/R/6x1wpRS9gmItct6U9fJYNFa8tZk94eTK0QNSiZmnj0nh5efuPKjlTFyXOXjblzXyWDRWvLm6xJZ0pnJwZyooaYqlSev9A3dk2sslnUTbP
dQjeKoj9XVuhtBqrA1ApRQ7JSE3FN+ZOnL2JvNIPZblRJi9yiteVN1aQzpbMbR+REDUmrUkmOOt9ZHaEbdfDMY4dKBfC0lETeVEVTNelsM7AbAzlRQ2ypCQHw9IuXvG7mfHa5j+MvXMJgONp+LZmSKHLcJroRNpXSaTOmVogacvSB2yGG1xWbI3CTIqPOeHQ/HsRjIaYk2GZgNwZyooYszveQd111kVGnKac8LrSUBLfV242pFaIG9SxpgtluhHfXNrxUqmQF6hBTEtxgYieOyIkaZEsTHH/4Tm+jzrRAPe0piUnBETlRxdIqRbIqP3yMOk2bFgPADfsiPPWROzmynQAM5EQVclm8UnWagK1rJx8DOVGF0hav1BlImVNOF/qSfwZyogrVsXgl9CDUtElY8s/JTqIKVb1HJlvJljcJS/4ZyIkMfO2dWfXiFVsQeuL0xUr2/JxEk7Dkn4GcKMHnKLfqxStpwaY/GOLJ0xdxoKKNnCdF1U9NdWCOnCjB9wRllRONtr4jsXjlaIh537qYyjNDq6/niJwoIaRHbVPqxia0vG9dJmHJP0fkRAkhddcbrxFPG5nH2ngzaoPQyzMZyIkS2vyobSs1XJzv7SqjM2njzYjKYyAnSmhiJaRLLXhWvXNydC7Aju6KbbgZ+ax5Hz/W/m4EEWCwOprKWnpRyw7dVVpYWNClpaXaP5eojUwj6TgI98aC0uET543pk95sF68cu8943DYtFDJ9z27UKZSPznr6KHrcthORC6q6kHydI3KihpmqZEzVJnknYduW9/VZDZTVY72JNghNYtUKUcOyJiDjoBR6vbPPaiCXn5mmiV0GcqKGuQTi7wyGxlJDAXDvHXMVnZlfPm9ELj8Tyg3OBwZyooa51ILfNNvF4nwPj97T27HPpwJ4/kI/iFWbPm9EWdfMNLHrq+1CGzGQEzVsfEEKgF0bMo8HpZffuLJrn882LPRxCZI+b0TJRTyz3Qg37IusC3omvbkYJzuJWmB8YjKt2qSNq07ztIFNuxHlnZjMM5nblr7wVWEgJ2qZtADV1KrTtJtLniDZ1I3Idvz+YIiDx15qzVqBohjIiQJS56rTOPAkFxclR9x5gnNTN6K05mLjqRagmqZiVW9ewRw5UUDqavA0nlMGkJqXz1ONUnV/dhuXCeUq5xqq3ryCI3KaWm1b+egqLfXi6ztlLbgBro248zwluLQ/qOL3Ev/88RcuYTAcZX4n36pOKTGQ01Rqep9GW7DKE8SS7733jjk8f6Fv/E5Avt4xLgEmHnEne7x0RHaMNpOfk3Ujqur3sjjfw8lzl1MDeVUpnqpTSgzkNJWarGKwBaulN9+2BuJkkN/fjfDDq2sYrev2e0+9+pYxBXL8hUt4d20jV3DsRjNYHW1Yv0NyxB0fp2wQrvr3knaDqjLFU/XcRukcuYjcIiIvi8jXReSSiHzMx4kRVanJMj5bsPpfX/mWNYgl66AHw9F2EI/Z2t8NhqNc+3qeXe6nBnFbXt5HHrjq34ttBNwRqbTJVtVzGz5G5GsAjqjqayLyXgAXRORLqvp1D8cmqkSTm0fYgtK6pRNpfzDEkTMr1r8vwzRqTgu8Ahg7LQJ+gnDVvxfbyLiOTolVNjErPSJX1e+q6mtb///fAbwOoP0zRjTVmqqeANJHhSYCe5D3ITlqTttpKC2g+uilUvXvZRK2dTPxmiMXkQMA5gF8xfB3jwN4HABuvfVWnx9LlFuVm0dkTVjaRoWP3tPbkSOPuYbw5EYSeYyPmjsi1htHWl8UH3ngOjb1aFt7Xx+8bSwhItcD+D8A/kBVv5D2Xm4sQZPKdfOEtKqVp1+8hHdW7ZUVsWhGcP3ePdu74tx7x5xxwtPF+OYUB469ZH1fVhoi1JLOUFS6sYSIRACeB3AqK4gTTTLXqou0UeGPUiYaYx0RnPzlu43HSAbzbtTB3mjGenOIZgSrV9e2l6rPdiNriV5WBckkjnZDUDqQi4gA+B8AXlfVPyl/SkTVqGO
0WHbCz2UhTtqo+NOLd2Hhtvft+p4AjFujdaMZrG3odpDvD4aIOoJoRjDaMI/tp2nDhlD4GJEfBvAbAL4mIhe3Xvukqn7Rw7GJvKhrAVDZqguXIPlje9JrFNJGxckAHy/iGTdaV9ywL8L3h2vGXPk0bdgQCh9VK/9XVUVVf1pVD239xyBOuVTd9L/qXhexslUXLkFyMBx566Vtu3EMVkf441+5u7HKHsqHTbOocXU0/a9rAVDZ8rajD9y+a2MJk7w3Ids13t+NjO+PdySaxFK9ScQl+tS4OpbL17kAqMyE3+J8D0tvvu1UfZLnJmS7xnujGXSjjrVkkJOXYeCInBpXx2jZx0KTuvZ8/PTiXXjmsUPbI2HbQqE8N6G0FApH3eHjiJwaV8douexCEx+TpXmqZpJbv5VdaJN2jX2NullD3hwGcmpcXbvelAlYZdM/eW8EyaD46D09vPzGlcJB0vc1ztNC1+X7MeiXw0BOjatjWXZZVdSH224EpqD//IV+qZRHmWvsErRtLXRdv1+dveAnEQM5tULbJ9Wqqg83vV7V5G+Ra2wKuqagbZuYrfP7TTNOdhI5qKo+3PR6k73Sk0xBN08vl/HvF08W27orcsVocRyREzkom/7Jk6OuevI3T366THCNZmT7+5kmbJPyfj/m2a9hICdyVLY+HHC7EVQ5+Zs3P227qSRb5kYzgg0A6+P9WcaqJrN6yBQpBWWe/RpvbWzzYBtbot2Se3KKYLtFra/RZlpqA9isIx//LFvpY7KKZvXqmrG7Ytwe9+Cxl6wpmeRnlvke4+14J1GlbWyJqJxkwBwMR+hGHTzz2CGvI8ysVElyZOv6JHHQ0sM8/jzbyH488PpI+Uxrnp2BnKgF6qrksAXU5Ocef+HS9ue6pJT2W3qYx71cstJFvlI+09qZkVUrRC1QdIRpahuQ1krAVH1jMhiOcrUgsHQR2H49qwFX3u6UTe652kYckRO1QJERpmkUe/S5FUA2e4rHr5lSJUfOrGRu6Jx8GkhLfQwsuw+Nv542ss97IwthEVmdGMiJWqBIpYppFGva1SeZoon/9+hzK9ZdgICdQTQr9eFyI0q7ERS5kbV9EVmdmFohylBH18Mivb/zTOwZ35vR+Hw8iGalPrJSHVk955kqKYcjcqIUddUrF1nc4jJxOf7ecSfPXd5Ov5gkg6jtptEfDHH4xPntksm90YyxZDJrMpepknIYyIlSVFVNMh64Z/dF+MGP1rbTHK43C1M6JpqRHTlywDyyTRvNm+q60xYGxa+nlUy65MCZKimOqRWiFFXUKyfTDO+sjnblql22cjOlY07+8t04+Ut3Z6ZobLnnuK47+X5T6iO5ujPtvPP0mqH8OCInSlFFvXLWcvWYy83CNorNGtnmnVw1pT7yNL/y0XaAvVXsGMiJUlTR98R1NF/laLVITjp507Atkzeddxt2aJpkDOREKaqYhHOZpExWfFQxEi2bky4yqm9qh6ZJx0BOlMH3JJxxkrIjeM91e/C94c6Kjzwj0bSAX8XNoM5KE/ZWScfuh0QNcA2srl3+bF0KP/PRuwDAWN1y/d49GKyOKuu06NO0djtMYvdDohZxHeW7jkSzFuyYVoDGbWfHm131B0M8efoilt58G59evCv7i9Skrg26Q8VATk5YMZCtimtky6fPiODscn/7+D5TDwrg1KtvYeG297Xmd8wFQ+kYyCkTKwayVXWNTCNRAFhXzdXrxHUFaEyxu2lW07hgyI4LgihT3haj08h2jY6cWSnVmyVe9NMx9Il17XXi2ro2iROJ4eCInDKxYmBTWurEdi2SI+ciFud7ePL0RePfxZ/rknoY30buh1fXUnutAM1s9kzFMJBTJu7GUryNK+Cn3tnld5CWekj+3Xhw3RvNYDja2PH+pjZ7pmKYWglYHe1VAbYYBYq1cR1X9umlzO/A9O9kcb6HV47dh2+eeBCv/7dfwGcfO5Srha4rpuXqwRF5oOoc6bBiIDu9lLXzTtmnl6K
/A9d/J1VNJDItVw8G8kDVvWR52isGXFMbwO7FN76eXor8Dppe2s60XD2YWgkURzr1ck1tFNnpp0pN/zthWq4eHJEHiiOdeuVJbbTp6aXpfydMy9XDSyAXkQ8D+O8AOgA+p6onfByX7LhkuX5tCtCu2vDvJMTrFprSgVxEOgD+DMDPA/g2gK+KyAuq+vWyxyY7jnTIBf+dTAcfI/IPAPiGqv4TAIjIXwB4BECjgTy0RQgu52t6zzR1fqNiOCKefD4CeQ/At8b+/G0AP5N8k4g8DuBxALj11ls9fKxdaIsQXM43tO9ERPWprWpFVZ9V1QVVXZibm6v0s0JbhOByvqF9JyKqj49A3gdwy9ifb956rTFNl1zl5XK+oX0nIqqPj0D+VQA/JSIHReQ6AL8K4AUPxy3MVlrV1tI8l/MN7TsRUX1KB3JVXQPwuwDOAXgdwBlVvVT2uGWEtgjB5XxD+05EVB8vdeSq+kUAX/RxLB9CK7lyOd/QvhMR1YebLxMRBYKbL0+Z0Oroiag4BvIJxJpzounCQD6B6m5dytE/UbMYyCdQnTXnHP0TNY/9yCdQnTXnXHFK1DwG8glUZ805V5wSNY+BfALVuUsNV5wSNY858glVV+tSl40LOBlKVC0Gciola8UpJ0OJqsdATqWljf6b3sWdaBowkLfIJKYgOBlKVD0G8pYom4Jo602g6V3ciaYBq1Zaokw9dnwT6A+GUFy7CZxd9r+/x9nlPg6fOI+Dx17C4RPnMz+D7XeJqscReUuUSUFk5aGTo/V775jDy29cyT16L/LU0Ib2u219WiHyhYG8JcqkINJuAqbg+/lX39p+T54UTtGJyyZ3cWfVDE0DplY8yZtySCqTgkhblGMKvkmuKZwQJy7ZQoCmAQP5ljKB2EeOusxqzLSbgGuQdXlfiKs4Q7z5EOXF1ArKP377qpUumoJIy0OfPHfZmLJJcgnGLqs424ZVMzQNGMhRPhC3YdRnuwmYgm+SazBuw8RlXiHefIjyYiBH+UBc9ajPVHUBuAVUU/A1Va0AwOET552O1+bAnRTizYcoL26+jM0AZgrEvdkuXjl2X+bPJ1MzwOaoz0fHQdOxTcp8XpXnn+ccGGyJ0tk2X+ZkJ8yThdGMYPXqmtPkZ5VtY12qToBylRhNV3bUuaCJaBJNfWolHgkOR+voiGBdFbPdCD+8uoZ3VkcA3Be+VDGCzJNnL5qTbzrHz8ZaROVMVSA3rXB8/kJ/O4isq6IbdSACjNZ3ppyaCiy2/LvtvT4/o8rKjvHfhS25xxJBIjdTk1oxPb6fevUt40gwHoknNRFYTGkfkzKVGHX3Q0n+LmxYIkjkZmpG5KbH97zTvE0ElvgJ4OkXL+26wQg2v0PPMjnoOoFYd2WHS96fJYJE7qYmkOcZTc92I7y7ttGa2uM4/56nsiPvIqc6ywrTfhcCsGqFKKdgA3neoDazNZGZFI9qY92og+MP3wmgnR37XM+higlEXyWCtpx8styTJYlEboIM5HlGm/F7TUG8G3Xw6D09a0vXNnXsO/rcCp5+8RIGqyOnoOa7EsVnF0HXDZvZtZDITZCBPM9o05aP7YjUuuAlD9M5jzY0Vzmk70oUnyN8l5w8SxKJ3AUZyPOMNm3v3VAtFRB8PfabjuMyas4Kar57jPge4WeliZqubScKSZCBPM9os4oaaR+P/WeX+7sqUeLj7O9GGAzNJZDj0oKa70qUumvN2bWQyF2QgTzPaLOK7ne2x/6nX7y03TY2XiVqKg1M658yHK1jbzSDbtTJLNHLCmo+K1Hq7iLIroVE7oIM5HlGm1XUSNtGwu+sjrZH2PHkqmm0nlVHPVgd4ZnHDm2f8/6tlgHjq03rCGrJtE/axHDRY7altp0oZKW6H4rISQAfAXAVwD8C+C1VHWT9XNu6H+Zl65aYZry07uCxl1IXI5m6LtZdildFR8Q2dFkkCpmt+2HZQH4/gPOquiYifwgAqvr7WT/nM5A3UWvs2lp2nAD45ok
HcXa5jyNnVozlkEA9gc3lmrm09s177W3H7IhgQ5WjbqIMtkBeKrWiqn879sdXAfxSmePl1VStcXzsJ05fdP6Zm2a7qTXtwOaK0uMP31nrSNt2zbKqRopce9sx09JQRJTNZ9Os3wbwN7a/FJHHRWRJRJauXLni5QOb7KO9ON9Dz7GCIs5np9W0f/axQ7j41P2VBzDbNTv+wqUdm0/v70bGn48nWItce5eKE+5wT5RfZiAXkS+LyP8z/PfI2Hs+BWANwCnbcVT1WVVdUNWFubk5Lydfda3x2eX+juCW3OjApTPhbDfaTpVUVdNuYzp/W25/MBzt6Az5w6triGZkx3vGJ1iLXHvXTo6sFSfKJzO1oqofSvt7EflNAA8B+KDWvG+crdZYARw49pK1K6ALl9RB/L+2nPdsN8LFp+6/9ud9kbFFrm2kWib/b1vm72q0rrhhX4R91+0xfn6ROu9kJYqt/41iM5/OfDmRm1I5chH5MICPA/hZVV31c0rusnaIL5NzdV0ivjjfw5OWXPn3xhb1nF3u4wc/Wtv1nqgjxjJC241k6c23nUoAbcv88xisjrD8X+83/l3ROu/x2va0SWPmy4nclc2R/ymA9wL4kohcFJE/93BOzuK9Mjsi1vcUzbnmSR3YRqHjr588d9kYSN9z3R7nQDwcrePUq2857W3pIz2RNboe36d0ththbzSDJ09fzNzj1HQME+bLidyUCuSq+h9U9RZVPbT133/xdWKuFud72MjI6BQJai7BOZa1w05abvp7lqX4tnNOflNbsMuzlP2GfVGhHYIW53t45dh9eOaxQ3h3bQPvrI5yb54cH8N2K2a+nCjbRGz1lhW0ivTnyLP9WXJ02pvtbk9wxumDvOeW55xNwc50/tGMIOrsnsB86iN3Ws/fZnwi9ciZldLVQz6uA9G0KrUgqCjfKzvTcq1lFtiUXWyUtfgnbas203dKboIRM60EtZ0/UH7Zu+uCqHgRVNFjRh3Be67bg+8N3XqwE026ShYEtcV4NURWw6q8xy27HN0WxIFrQTmtImY86N57xxyev9B3nmC0nX/ZYOiy5ybgPpqObzjD0fr27+6GfRF+8KO17S6QnPwkspuIQA7Uu+ekC9dgF7NVxCS/08Jt72u8kZRL3tq1qVdyJL6uup0SSk4Oc2MJIrOJCeRtU2SSLvkzaft2xn/35OmLOHnuspeA7ppKstWQF+mZYqvOsd0EOflJtFswgTy0jXhtwS7rZ2KmOvInT1/E0ptvY+G29zn1Ocm7QbVr7xRbDXmRuYi8gZmTn0S7BRHIQ9yIN2uxUlIyFWEaqSqAU6++hb9e+W5qlUg8VzA+OZp1zfLskemzV7jthjfbjfDu2gY3liByEEQgD3EjXtfl6IC5aiWtjty2DVwcrONrZas5N12zvL1TfM1J2Eb3xx++EwA3liByEUQgD3Uj3qzl6GnpiCKpmY5I5hOA7ZoV6Z3iI92VNbpn4CbKFkQgb+tGvHkCmUs6Yvx4s/vMbWRtXPb4BOzXLG/vFJ/priYqjkKbcyFKE8SCoDZuEeb7nEzHmxHApc9VnJqJc+M20Yzg+r17MFg1L7DJE9xcdhBqqzb+eyJyEfSCoDZuxOs7b2863oYCIkDavXZ2awOIJ05fhKl3WDzhObu1gXPcRte2CMn13NPSXW0f7YY450KUJohADvh//C4bbHzn7a2Tmxkj8sFwtD35mXzvDfsiPPWRza3jDp84v2uStEzwsqW79nej1lcYhTrnQmQzEU2z8oofrV3awdqkNXnK2lkoz/HK+P7wWv9z38HL1lRMBI1tv+eKDbpo0kxlIPex16ctkN17x1yhm4TrNmh5rKtuf7bv4GXr+Dgw7IAEmG8YRW54PuTpbEkUgmBSKz75GJ3a8vZF86956s6BzbQJAOPWcabPLrqjTxpTuss24Zq8YTS5yKuNcy5EZUxlIPdVzmgKZLZt31xuEnnrzl3byX5nMKwteLneMJqecGxbkzWiMoIN5GU
mK6sYncZ83iSA9MCbbN+bdk7x+9sy2uWEI5E/QQbyso/lVY5Ofd4kXAKvyyi+7tyvy3m3dZEXUYiCDOQ+HsurGp02mX+t47N91YgffeB2HP3LFYzWr80DRB3hhCNRAUEG8rY/ljeZf63ys01PQkefW8HTL16yrhZNlZzLrX+RMdFECLL8kHXAzTA9CY02FO+sjnLX4588d3nXDkCjDW1VvTlRKIIckVc5WRmCs8t9PP3ipe3Sw9luhOMP31n5U4DLE89wtI4jZ1YApM9XtP2piigkQY7IbYtRpqGc7OxyH0f/cmVH/fhgOMLR51YqX1Dj+sQzvhAp77H4VEWUX5CBHNgM5q8cuw/fPPEgXjl231QEcWArJbG+O5lcR1oiz+rTrJWyXF1J5E+QqZVplpZ6qDotkayK2b/VUdF0Y8k6H66uJPKHgTwwaTsH1ZGWSFbFnF3u48iZFWM7gazz4epKIj+CTa1Mq6MP3I6os7vxeDTTTA324nwPf/wrdxdOkzTVOItoknBEHph4BFu2asXn5g9F0yRNNs4imiRBbPVG6fIG5bZsdRbydnFETQh6q7c0bd9WrGpFRrVNdx6MsZacyI+gc+Q+dvoJXZFNMmyTpWkdFKvAWnIiP4IO5D52+gldkVFtx7RLc8rrVWEtOZEfQQdyPpoXG9Xadh5K25GoCtO8QpfIp6Bz5OxpXazvTM9y3XoNXDfWkhOVF/SInI/mxUa1vG5Ek8XLiFxEjgD4IwBzqvpvPo7pgsu8N+Ud1fK6EU2W0nXkInILgM8BuAPAPS6BnHXkRET52erIfaRWngHwcXB/FyKiRpQK5CLyCIC+qq44vPdxEVkSkaUrV66U+VgiIhqTmSMXkS8D+EnDX30KwCcB3O/yQar6LIBngc3USo5zJCKiFJmBXFU/ZHpdRO4CcBDAimwuJLkZwGsi8gFV/RevZ0lERFaFq1ZU9WsAfiL+s4j8M4CFOqtWiIgo8DpyIiLyuLJTVQ/4OhYREbnjiJyIKHAM5EREgQu6aRaFY9o3ACGqEgM5VY57cxJVi6kVqhw3ACGqFgM5VY4bgBBVi4GcKse9OYmqxUBOleNGFkTV4mQnVY4bWRBVi4GcasG9OYmqw9QKEVHgGMiJiALHQE5EFDgGciKiwDGQExEFTlTr3z5TRK4AeLPij7kRAHcr4nWI8TrwGsRCvg63qepc8sVGAnkdRGRJVReaPo+m8Tps4nXgNYhN4nVgaoWIKHAM5EREgZvkQP5s0yfQErwOm3gdeA1iE3cdJjZHTkQ0LSZ5RE5ENBUYyImIAjcVgVxEjoiIisiNTZ9LE0TkpIi8ISJ/LyL/W0Rmmz6nuojIh0Xksoh8Q0SONX0+TRCRW0TkZRH5uohcEpGPNX1OTRGRjogsi8hfN30uPk18IBeRWwDcD+Ctps+lQV8C8B9V9acB/AOATzR8PrUQkQ6APwPwCwDeD+DXROT9zZ5VI9YAHFHV9wP4TwB+Z0qvAwB8DMDrTZ+EbxMfyAE8A+DjAKZ2VldV/1ZV17b++CqAm5s8nxp9AMA3VPWfVPUqgL8A8EjD51Q7Vf2uqr629f//HZuBbOqaw4vIzQAeBPC5ps/Ft4kO5CLyCIC+qq40fS4t8tsA/qbpk6hJD8C3xv78bUxhABsnIgcAzAP4SsOn0oTPYnNQt9HweXgX/A5BIvJlAD9p+KtPAfgkNtMqEy/tOqjqX22951PYfMw+Vee5UTuIyPUAngfwhKp+v+nzqZOIPATgX1X1goj8XMOn413wgVxVP2R6XUTuAnAQwIqIAJvphNdE5AOq+i81nmItbNchJiK/CeAhAB/U6Vk80Adwy9ifb956beqISITNIH5KVb/Q9Pk04DCAh0XkFwHsBfDjIvJ5Vf31hs/Li6lZECQi/wxgQVVD7XpWmIh8GMCfAPhZVb3S9PnURUT2YHNy94PYDOBfBfCfVfVSoydWM9kcyfxPAG+r6hMNn07jtkbkv6eqDzV8Kt5MdI6ctv0
pgPcC+JKIXBSRP2/6hOqwNcH7uwDOYXOC78y0BfEthwH8BoD7tn7/F7dGpjQhpmZETkQ0qTgiJyIKHAM5EVHgGMiJiALHQE5EFDgGciKiwDGQExEFjoGciChw/x/repbJhkPWUgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.scatter(x[:, 0], x[:, 1])" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "step 0, loss 0.26484909653663635\n", "step 200, loss 0.08326209336519241\n", "step 400, loss 0.051809437572956085\n", "step 600, loss 0.038577087223529816\n", "step 800, loss 0.031144658103585243\n", "step 1000, loss 0.026323309168219566\n", "step 1200, loss 0.022914284840226173\n", "step 1400, loss 0.020361438393592834\n", "step 1600, loss 0.01836979016661644\n", "step 1800, loss 0.016767438501119614\n" ] } ], "source": [ "batch_size = 20\n", "max_steps = 2000\n", "learning_rate = 0.01\n", "x, y = torch.tensor(data[0]).float(), torch.tensor(data[1])\n", "lr = LogitRegression(2)\n", "lossi = []\n", "\n", "for t in range(max_steps):\n", " ix = (t * batch_size) % len(x)\n", " xx = x[ix: ix + batch_size]\n", " yy = y[ix: ix + batch_size] # (20)\n", " logits = lr(xx) # (20, 2)\n", " loss = F.cross_entropy(logits, yy)\n", " loss.backward()\n", " with torch.no_grad():\n", " for p in lr.parameters():\n", " p -= learning_rate * p.grad\n", " p.grad = None\n", " if t % 200 == 0:\n", " print(f'step {t}, loss {loss.item()}')\n", " lossi.append(loss.item())" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAAsTAAALEwEAmpwYAAAgCklEQVR4nO3deXRc9X338fdXo93abEneF3kFHMwWYQik0AWICSnQhiZkaZzlHJonoW3K6emB0iZ5aNoScp407ROawJPwkJCFJC1pnGLqsGVpwMHyAsY2tmXjRbZsy5KtxbK2me/zx1z7GYmRNbJmka4+r3N0dOd375371Z3RR1e/+5t7zd0REZHwyst1ASIiklkKehGRkFPQi4iEnIJeRCTkFPQiIiGXn+sChqqpqfG6urpclyEiMqFs3LjxuLvXJps37oK+rq6OhoaGXJchIjKhmNn+4eap60ZEJOQU9CIiIaegFxEJOQW9iEjIKehFREJOQS8iEnIKehGRkAtN0Hf3DfDln+1k84ETuS5FRGRcCU3Qn+6L8i8vNLL1UHuuSxERGVdCE/Rn6D4qIiKDhSbozSzXJYiIjEuhCfozdGtEEZHBQhP0Op4XEUkuNEF/ho7nRUQGC03Qq4teRCS50AT9GeqiFxEZLKWgN7NVZrbTzBrN7N4k8+8xs+1m9pqZPW9mCxLmRc1sS/C1Jp3FD6pBvfQiIkmNeIcpM4sADwM3Ak3ABjNb4+7bExbbDNS7e7eZ/Q/gIeD9wbzT7n5Zesseng7oRUQGS+WIfiXQ6O573b0PeBK4LXEBd3/R3buDh+uBuektMwU6oBcRSSqVoJ8DHEx43BS0DecTwDMJj4vNrMHM1pvZ7clWMLO7gmUaWlpaUihpeBpHLyIyWFpvDm5mHwbqgesTmhe4+yEzWwS8YGZb3X1P4nru/ijwKEB9ff15JbVG3YiIJJfKEf0hYF7C47lB2yBmdgNwP3Cru/eeaXf3Q8H3vcDPgcvHUK+IiIxSKkG/AVhqZgvNrBC4Exg0esbMLgceIR7yxxLap5pZUTBdA1wLJJ7ETRsd0IuIJDdi1427D5jZ3cA6IAI85u7bzOwBoMHd1wBfAsqAHwUXFzvg7rcCFwGPmFmM+B+VB4eM1kk7ddGLiAyWUh+9u68F1g5p+2zC9A3DrPcSsGIsBaZKV68UEUkufJ+M1Uh6EZFBQhP0Op4XEUkuNEF/hrvG0ouIJApN0J/pov/HZ95g4X1rz72wiMgkEpqgFxGR5EIT9Lp6pYhIcqEJehERSS40Qa9h9CIiyYUm6EVEJDkFvYhIyCnoRURCLjRBrz56EZHkQhP0IiKSXGiCXuPoRUSSC03Qi4hIcqEJevXRi4gkF5qgFxGR5EIT9DqgFxFJLjRBLyIiyYUm6IfeM3bdtiO0dvXmqBoRkfEjNEE/1J88sZGPPb4h12WIiORcaII+WR/9gbburNchIjLehCboRUQkudAEfbJx9LpHuIhIiIJeRESSC03QDx11IyIicaEJehERSU5BLyIScgp6EZGQSynozWyVme00s0YzuzfJ/HvMbLuZvWZmz5vZgoR5q81sd/C1Op3Fi4jIyEYMejOLAA8DNwPLgQ+Y2fIhi20G6t39EuDfgIeCdacBnwOuAlYCnzOzqekrf2itgx+7xleKiKR0RL8SaHT3ve7eBzwJ3Ja4gLu/6O5nPoa6HpgbTL8LeNbd29z9BPAssCo9pYuISCpSCfo5wMGEx01B23A+ATwzmnXN7C4zazCzhpaWlhRKSk4DLEVE3iqtJ2PN7MNAPfCl0azn7o+6e72719fW1qatno6eAb7y3K60PZ+IyESUStAfAuYlPJ4btA1iZjcA9wO3unvvaNZNl2QfmvrKc7sztTkRkQkhlaDfACw1s4VmVgjcCaxJXMDMLgceIR7yxxJmrQNuMrOpwUnYm4K2jIjGdPJVRGSo/JEWcPcBM7ubeEBHgMfcfZuZPQA0uPsa4l01ZcCPgqPqA+5+q7u
3mdnfEf9jAfCAu7dl5CcREZGkRgx6AHdfC6wd0vbZhOkbzrHuY8Bj51ugiIiMjT4ZKyIScgp6EZGQU9CLiIScgl5EJOQmRdDrmjciMplNiqAXEZnMJkXQ90ddR/UiMmlNiqBf9jfP8NmfbMt1GSIiOTEpgh7gifX7c12CiEhOTJqgFxGZrBT0IiIhp6AXEQk5Bb2ISMgp6EVEQk5BLyIScgp6EZGQm1RBf7yrd+SFRERCZlIFff0XnqOnP5rrMkREsmpSBT1AXzSW6xJERLJq0gW9iMhkM+mCXhexFJHJZtIFvYjIZKOgFxEJOQW9iEjIKehFREJu0gX9t17ax9am9lyXISKSNZMu6L/87C5+/6v/nesyRESyZtIFvYjIZKOgFxEJuZSC3sxWmdlOM2s0s3uTzL/OzDaZ2YCZ3TFkXtTMtgRfa9JVeDKffc/yTD69iMiElD/SAmYWAR4GbgSagA1mtsbdtycsdgD4KPCXSZ7itLtfNvZSR1ZbXpSNzYiITCgjBj2wEmh0970AZvYkcBtwNujdfV8wT1cMExEZZ1LpupkDHEx43BS0parYzBrMbL2Z3Z5sATO7K1imoaWlZRRPPfR5Ul/2WEfPeW9HRGQiycbJ2AXuXg98EPiKmS0euoC7P+ru9e5eX1tbe94bGs0Fy1b+w/O06kYkIjIJpBL0h4B5CY/nBm0pcfdDwfe9wM+By0dRX0ad6O7PdQkiIhmXStBvAJaa2UIzKwTuBFIaPWNmU82sKJiuAa4loW9fREQyb8Sgd/cB4G5gHbAD+KG7bzOzB8zsVgAzu9LMmoA/Ah4xs23B6hcBDWb2KvAi8OCQ0To5NZo+fRGRiSqVUTe4+1pg7ZC2zyZMbyDepTN0vZeAFWOsMWWjDW7dhEREJoNQfTJWwS0i8lahCvrR+s76/TTsa8t1GSIiGTWpg/7xl/Zxx9dfznUZIiIZFaqg18lVEZG3ClXQq49eROStQhX0IiLyVgp6EZGQU9AD7/jH5+nqHch1GSIiGaGgB5rbe9h+uCPXZYiIZISCXkQk5BT0IiIhp6APuMZmikhIKegD33p5Hy++cSzXZYiIpJ2CPrB26xE+9viGXJchIpJ2CnoRkZBT0IuIhJyCfohoTCdlRSRcFPRDLP7rtext6cp1GSIiaaOgT2J7sz4lKyLhoaAXEQk5Bb2ISMgp6JO4+3ubuf/HW3NdhohIWijoh/Hd3xzIdQkiImmhoBcRCTkF/Tl09PTnugQRkTFT0J/DJZ//GZsPnMh1GSIiY6KgH8HWQ+25LkFEZEwU9CPQZepFZKJT0I/gqU1NrHn1cK7LEBE5bykFvZmtMrOdZtZoZvcmmX+dmW0yswEzu2PIvNVmtjv4Wp2uwrPl1aZ2/uz7m3NdhojIeRsx6M0sAjwM3AwsBz5gZsuHLHYA+CjwvSHrTgM+B1wFrAQ+Z2ZTx162iIikKpUj+pVAo7vvdfc+4EngtsQF3H2fu78GxIas+y7gWXdvc/cTwLPAqjTUnXUPv9hI70A012WIiIxaKkE/BziY8LgpaEtFSuua2V1m1mBmDS0tLSk+9VtdNKvivNcdyZfW7eSb//1mxp5fRCRTxsXJWHd/1N3r3b2+trb2vJ9nyfQydn4hc/8wnO7TEb2ITDypBP0hYF7C47lBWyrGsu55KcqPZOy5n3n9CD/e3JSx5xcRyYRUgn4DsNTMFppZIXAnsCbF518H3GRmU4OTsDcFbRNS47Eu/uIHr+a6DBGRURkx6N19ALibeEDvAH7o7tvM7AEzuxXAzK40sybgj4BHzGxbsG4b8HfE/1hsAB4I2kREJEvMx9lHP+vr672hoWFMz1F379NpqmZ43/74Sq5bdv7nE0RE0snMNrp7fbJ54+Jk7ET0zOtHcl2CiEhKFPTnacO+Nv5LYS8iE0Aog/4vblhGWVF+RrfReKyLT35nY0a3ISKSDqEM+j+/YSmP/vHbs7KtjfvbGIgO/UCwiMj4Ecqgz6b3fu1lvvL
c7lyXISIyLAV9Gmzcf4LjXb25LkNEJCkFfRq8vLeV+i88l+syRESSUtCnUeOxTvXXi8i4E9qgv3z+VC6dV5XVbd7w5V/y0LqdWd2miMhIQhv0JYURfvLpa7O+3Ud/uZenX2vO+nZFRIYT2qDPpU9/bxO/2HX+19UXEUknBX2GrH7sFb7xq720dGo0jojkloI+g77w9A6u/eILrN2qrhwRyZ3QB/1z91zPyoXTcrb9voEYn/ruJp54eR+7j3bmrA4RmbxCH/RLppfx97dfnOsy+NufbOPGf/oln1+zjVhsfF0aWkTCLfRBP948/tI+rvz75/jaz/fQN6Ax9yKSeZMi6CtLCnJdwiCtp/r44n+9wbK/eYaXGo9zrLMn1yWJSIhl9lq+48T0iuJclzCsD37jNwD86e8u4TM3LCOSZzmuSETCZlIc0QNs/tsbc13COf3vFxpZ/Ndrue+p12g8ppO2IpI+k+KIHmDqlMJcl5CS779ykO+/cpCPX7uQ65bV8NsXTM91SSIywU2aI3qAf/3QFbkuIWWP/fpNPvp/N1B379P89NXDGpopIudt0hzRA1w4szzXJZyXP/3+ZgCWzSjjoTsuZVZlMTPG8XkHERlfJtUR/aLaMh7/2JW5LuO87Traxe0P/5qr/uF57vnBFhqPddF2qi/XZYnIODepjuiB0PR5P7X5EE9tPgTAh6+ezx9fXcfUKQVML9eRvogMNumCPoy+s/4A31l/AIDbL5vNnSvnU1texOLashxXJiLjwaQM+r+7/WL+/unt9PSH75Op/7HlMP+x5TAQPyfxkXfUsaC6lGuX1OS4MhHJFXMfX9ddqa+v94aGhoxvZ9fRTm76p19mfDvjyW2XzWbFnEpuv3wOZUX5FBdEcl2SiKSJmW109/pk8yblET1AUf6kOg8NwE+2HOYnWw7zhad3AHDp3Er+atWFzJtayuyqYvIjk2+fiEwGKQW9ma0C/hmIAN9w9weHzC8Cvg28HWgF3u/u+8ysDtgBnLmR6np3/2Saah+T2VUlrJhTydZD7bkuJWdebWrnQ8ElGAA+8c6FvH3BVKYU5XP9stocViYi6TRi142ZRYBdwI1AE7AB+IC7b09Y5lPAJe7+STO7E/gDd39/EPT/6e4pXyc4W103Z9Td+3TWtjXR1FWXcmXdND7yjjpKiyIsrJ5Cnq7FIzIujbXrZiXQ6O57gyd7ErgN2J6wzG3A54PpfwO+amYTIhFqy4t0u79h7GvtZl9rNz/a2HS27ZYVs1g+u4IZFcXcuHwGeQblxePr6qAiMlgqQT8HOJjwuAm4arhl3H3AzNqB6mDeQjPbDHQAf+Puvxpbyen19J++k5f3tvLnT27JdSkTwtNbm3l6yK0Ry4ry+bPfW0JteRELa8q4cGa5TvSKjCOZPhnbDMx391YzezvwH2b2NnfvSFzIzO4C7gKYP39+hksabHpFMTdcNCOr2wybrt4B/mHtG29p/62lNbznklkUF0R42+xKZlcVU1o4ac//i+RMKr91h4B5CY/nBm3Jlmkys3ygEmj1+AmAXgB332hme4BlwKBOeHd/FHgU4n305/FzjElpYYQ/uX4Rj/xib7Y3HWq/2n2cX+0+/pb2d6+YydWLqjHg6kXVVJcVUVaUT+EkHAklkg2pBP0GYKmZLSQe6HcCHxyyzBpgNfAycAfwgru7mdUCbe4eNbNFwFJg3KWpmXHfzRcp6LNk7dYjrN165C3tF8+poH7BNArz81h18UwqSwoojOQxb1ppDqoUCY8Rgz7oc78bWEd8eOVj7r7NzB4AGtx9DfBN4AkzawTaiP8xALgOeMDM+oEY8El3b8vED5IOa+6+lu+uP8APGg6OvLCk3euHOnj9ULxX79FfDv6jO3dqCVNLC1l9TR1TSwuIxpyrFlXj7lSVTox7DYjkyqT9ZOxw3J2F963N2fbl/Fy7pJoZFcXMrCjmdy6cTn80Rm1ZEfOmlZJnpm4hCT19MnYUzIyHP3gFn/7eply
XIqPw68bWs9P/+vM9b5lflJ/HLcGJ4YtnV3LhrHK6egZYVDuFqaWFFETy9MdAQktBn8Qtl8xi66HFfP0Xbw0MmZh6B2I8tWnoGILBpk0p5Ir5VfRFnVtWzGRWZQktnb2sXBg/bxDJM2rKirJUsUj6qOtmGLGY090f5eLPrct1KTLOlBfn4w6rr1nAtClFtJ/u53cvnE40FsPMWDajnGjUqSzVB8kke87VdaOgH8FTm5q454ev5roMmaCmlxdRVJDHopoyrl9WS19w7uCCmeV0nO5nZmUx1WVFxGI+YW5gL+OT+ujH4A+vmMvUKYXc84MtnOjuz3U5MsEcCy6vcbDtNL/Y1TLi8mVF+XT1DvDuFTNZMr2ctlO9XLO4hqrSAg6f7OGK+VUURPLoi8aYO7UEd/QpZBmRjuhHQRdAk/FqQXUpVaWFDERj3HbZbAoieXT2DHDVwml090fJzzOWTC+juy9KeVH+2SGpOgEdHuq6SZO+gRhf/8UevvzsrlyXIpI2C6pLOdndz8qF01g+q4K2U30sn13BzMpimtq6WT67ksqSAlo6e1lUO4XCSB4DMWfalEIMdEXTcUJBn0buTnN7D9c8+EKuSxEZN5bNKKOiuICTp/u5ZUV8GGtrVy/XLKmmpz9Gd1+US+ZW0tkzQH6eMW9aKd19A0wpzKesON6DXKAb34yJgj4D2k718dSmprN3axKRsaspK+J4Vy9XzK9iyfQymtt7uHpRNTMrimls6eKyeVVUlhSw7/gpVsytpCg/j2OdvSyZXgZAf9SpKSuclOcuFPQZ1H66n/d+7SUaj3XluhQRGSKSZ+TnGb0DMa5bVou74w7XLqmhpz+KA2+bXcHxrl7KivJZUD2FYx09VJcVMr28mNZTfVRPKaSiuICuvgGqSgrIjxh5ZuPuPxAFfYa5O3taTvE/f7ot6dUaRSScivLz6B2IsXxWBRUl+RxsO82Ny2dQUhjhSHsP1yyuJubOoZM9XFk3ldN9UdpP9/O22ZV09vQTdaeuegodPf0URPKYU1Vy3v+JKOiz6Oc7j/Hvmw7x01cP57oUEZlg3vW2GTzyx0mzekTnCvrx9b9HCPz2BdP5p/ddyk/vfifX6QbbIjIKLyVcsymdFPQZkB/JY8XcSr71sStZf9/v8b76ubkuSUQmgEwNVVXQZ5CZMbOymIfuuJTn7rmeJz6xkiJ9QEVEhpGpjyToEghZsmR6GUuml7HzCzfz0p7j7D7axefWbMt1WSIyjkQylPQK+hy4ZnEN1yyu4YNXzWfDm21sb+7QeHwRIc8U9KFTEMnjmiU1XLOkho9du5Dthzt4s/UUn3lyM7HxNRhKRLJAQR9ykTxjxdxKVsyt5NZLZ3O0o4eWzl4e+/Wb7DraefZeqiISXuq6mWRmVBQzo6KYL7/vMqIxpz8a4/kdxxiIxXh2+1H+87XmXJcoImmWoQN6Bf1EEMkzInkRbrlkFgC3XTaHh+4Y4FRvlB3NHWzY10ZLZy9PbjiY40pFZCzUdSODlBbmU1qYT2157dkPZj343ktoOtGNO/x343GOdvRw+ORpftjQlONqRSQV6rqRlMydWgrAB1bOP9v20B2X0trVi5nxq90tmBknu/t4/KV9dJwe4HhXb67KFZEEGkcvY1JdVgTEu33O+Mg76oD4RdnePH4KgK7eAV58o4XpFUXsbenie785QCTP6OgZyHrNIpONjuglY8yMRbVlZx9fMrfq7PT9tyzH3YnGnMMneygqyKO5vYc9x7qYNqWQbYfb2d/aTdSdZ7cfpVN/EETOm/roJWfMjPyIMb863i00o6KYy+ZVAfA7F04ftKy7E3M42d1HZ88AxQURtje3Y2YURvL41e7j1JQV0tMf5aevNlNWnM/Rjh6aTpzO9o8lMu4o6GVCMDMiFu8qOtNdNLOy+Oz8a5fUnJ2++3eXnp12d/qiMWIxONLRw5TCCN19UXYd7WR2VQmtp/rY0dxBXfUUmk50s6eli1mVJbzW1M6+1lO
UFeWz5eDJrP2cIpmgrhsJNTOjKD9+w4WFNVPOttclTF+fwmWfz9xBqGcgSm9/jML8PI529OBASUGEXUc7KS/Opyg/wsb9J5hTVYIDr7zZyuLaMnoHYrzyZhuLp5fR0tnDy3taWVA9heb20+w62kVNWSHHu/rS/eOLANDR05+R51XQS6iYGWZnhp/G2xLPP8yuKjk7ffGcyrPTNy6fcXZ69TV1I25nIBrDzOgdiHK6L0pJYYT20/2c7otSWVLAkY4eojGnqqSQN1tPURAxKooL2N7cwbTSQkoKI2w+cIJ500rJM2Pj/hMsnl5GLOY07D/BBTPK6OqN0rCvjaUzyjne1ctrTSdZMr2Mwyd7ePP4KeZNK+Fgm7q8wmR/a3dGnjelO0yZ2Srgn4EI8A13f3DI/CLg28DbgVbg/e6+L5h3H/AJIAr8mbuvO9e2JvodpkRyJRpzBmIxImb0DMSIRp2igjw6TvfH/6MpjNDS2UtBXh4lhRGaTnRTXpxPcUGExmNdzKgoJj/P2HGkk7rqUmIO2w63s2xGOb39MXY0d3DRrAo6e/rZdbSLi2aV03aqjwNt3Vwws5zm9h5au3pZWFPG/tZT9A7EmFNVwq6jnRTm51FVWsDWQx1MKy0gkpfHpgMnmDu1hN7+GBsPnOCCGeW0nurlaEcvMyqKONoxOYf97nvwlvNab0y3EjSzCLALuBFoAjYAH3D37QnLfAq4xN0/aWZ3An/g7u83s+XA94GVwGzgOWCZu0eH256CXkTO5cwJ/zyDgZgzEHUKIkZ/1OkbiFFSGOF0X5T+WIyyonw6evqJxaCsOJ8Tp/owg/KiAlq6eiiMRCgtitB8soey4nxKCiLsbz1FTXkRhZE89rR0nf1syp6WLhbVTKEvGmNvyymWTC+ju2+A/a3dLJtRzsnufg63n2bp9DKOdfbS3t1PXc0UDp88Te9AlDlVpexrPUV+njGjophdRzupLCmgqrSAHc2dzKwsZkF1KdcsrhlhDyQ31qB/B/B5d39X8Pi+YGf/Y8Iy64JlXjazfOAIUAvcm7hs4nLDbU9BLyIyemO9Z+wcIPEiKk1BW9Jl3H0AaAeqU1wXM7vLzBrMrKGlpSWFkkREJFXj4r527v6ou9e7e31trW6oLSKSTqkE/SFgXsLjuUFb0mWCrptK4idlU1lXREQyKJWg3wAsNbOFZlYI3AmsGbLMGmB1MH0H8ILHO//XAHeaWZGZLQSWAq+kp3QREUnFiOPo3X3AzO4G1hEfXvmYu28zsweABndfA3wTeMLMGoE24n8MCJb7IbAdGAA+fa4RNyIikn4pjaPPJo26EREZvbGOuhERkQlMQS8iEnLjruvGzFqA/WN4ihrgeJrKSSfVNTqqa3RU1+iEsa4F7p50fPq4C/qxMrOG4fqpckl1jY7qGh3VNTqTrS513YiIhJyCXkQk5MIY9I/muoBhqK7RUV2jo7pGZ1LVFbo+ehERGSyMR/QiIpJAQS8iEnKhCXozW2VmO82s0czuzfK255nZi2a23cy2mdmfB+2fN7NDZrYl+Hp3wjr3BbXuNLN3ZbC2fWa2Ndh+Q9A2zcyeNbPdwfepQbuZ2b8Edb1mZldkqKYLEvbJFjPrMLPP5GJ/mdljZnbMzF5PaBv1/jGz1cHyu81sdbJtpaGuL5nZG8G2f2xmVUF7nZmdTthvX09Y5+3B698Y1G4Zqm3Ur126f2eHqesHCTXtM7MtQXtW9tk5siG77zF3n/BfxC+2tgdYBBQCrwLLs7j9WcAVwXQ58VsvLgc+D/xlkuWXBzUWAQuD2iMZqm0fUDOk7SHg3mD6XuCLwfS7gWcAA64GfpOl1+4IsCAX+wu4DrgCeP189w8wDdgbfJ8aTE/NQF03AfnB9BcT6qpLXG7I87wS1GpB7TdnaJ+N6rXLxO9ssrqGzP9fwGezuc/OkQ1ZfY+F5Yh+JdDo7nvdvQ94ErgtWxt392Z33xRMdwI7SHInrQS3AU+6e6+7vwk0Ev8ZsuU
24FvB9LeA2xPav+1x64EqM5uV4Vp+D9jj7uf6NHTG9pe7/5L4FVeHbm80++ddwLPu3ubuJ4BngVXprsvdf+bxO7gBrCd+f4dhBbVVuPt6j6fFtxN+lrTWdg7DvXZp/509V13BUfn7iN/Deljp3mfnyIasvsfCEvQp3bIwG8ysDrgc+E3QdHfwL9hjZ/49I7v1OvAzM9toZncFbTPcvTmYPgLMyEFdZ9zJ4F++XO8vGP3+ycV++zjxI78zFprZZjP7hZn9VtA2J6glW3WN5rXL9j77LeCou+9OaMvqPhuSDVl9j4Ul6McFMysD/h34jLt3AF8DFgOXAc3E/3XMtne6+xXAzcCnzey6xJnBUUtOxtha/EY2twI/CprGw/4aJJf7Zzhmdj/x+zt8N2hqBua7++XAPcD3zKwiy2WNu9duiA8w+IAiq/ssSTaclY33WFiCPue3LDSzAuIv5Hfd/SkAdz/q7lF3jwH/h//f3ZC1et39UPD9GPDjoIajZ7pkgu/Hsl1X4GZgk7sfDWrM+f4KjHb/ZK0+M/so8B7gQ0FAEHSLtAbTG4n3fS8Lakjs3snk+2y0r10291k+8IfADxLqzdo+S5YNZPk9FpagT+V2hxkT9P99E9jh7l9OaE/s3/4D4MxogKzcYtHMpphZ+Zlp4ifzXmfwrR9XAz9JqOsjwZn/q4H2hH8vM2HQUVau91eC0e6fdcBNZjY16LK4KWhLKzNbBfwVcKu7dye015pZJJheRHz/7A1q6zCzq4P36EcSfpZ01zba1y6bv7M3AG+4+9kumWzts+GygWy/x873bPJ4+yJ+tnoX8b/M92d52+8k/q/Xa8CW4OvdwBPA1qB9DTArYZ37g1p3koaREMPUtYj4aIZXgW1n9gtQDTwP7AaeA6YF7QY8HNS1FajP4D6bQvwG8pUJbVnfX8T/0DQD/cT7PT9xPvuHeJ95Y/D1sQzV1Ui8n/bMe+zrwbLvDV7fLcAm4PcTnqeeeOjuAb5K8Gn4DNQ26tcu3b+zyeoK2h8HPjlk2azsM4bPhqy+x3QJBBGRkAtL142IiAxDQS8iEnIKehGRkFPQi4iEnIJeRCTkFPQiIiGnoBcRCbn/BwO9k/UF1xiuAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(lossi)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }