{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from utils import Scalar, draw_graph\n", "from linear_model import Linear, mse" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "4520854528forward\n", "\n", "grad=None\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4520854576forward\n", "\n", "grad=None\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4520854528forward->4520854576forward\n", "\n", "\n", "\n", "\n", "4520854624forward\n", "\n", "grad=None\n", "\n", "value= 4.00\n", "\n", "-\n", "\n", "\n", "4520854576forward->4520854624forward\n", "\n", "\n", "\n", "\n", "4520854096forward\n", "\n", "grad=None\n", "\n", "value= 0.00\n", "\n", "a\n", "\n", "\n", "4520854096forward->4520854528forward\n", "\n", "\n", "\n", "\n", "4520853952forward\n", "\n", "grad=None\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4520854096forward->4520853952forward\n", "\n", "\n", "\n", "\n", "4520854672forward\n", "\n", "grad=None\n", "\n", "value= 8.50\n", "\n", "mse\n", "\n", "\n", "4520854624forward->4520854672forward\n", "\n", "\n", "\n", "\n", "4520854144forward\n", "\n", "grad=None\n", "\n", "value= 0.00\n", "\n", "b\n", "\n", "\n", "4520854144forward->4520854576forward\n", "\n", "\n", "\n", "\n", "4520854384forward\n", "\n", "grad=None\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4520854144forward->4520854384forward\n", "\n", "\n", "\n", "\n", "4520854192forward\n", "\n", "x2= 2.00\n", "\n", "\n", "4520854192forward->4520854528forward\n", "\n", "\n", "\n", "\n", "4520854240forward\n", "\n", "y1= 1.00\n", "\n", "\n", "4520854480forward\n", "\n", "grad=None\n", "\n", "value= 1.00\n", "\n", "-\n", "\n", "\n", "4520854240forward->4520854480forward\n", "\n", "\n", "\n", "\n", "4520854288forward\n", "\n", "y2= 4.00\n", "\n", "\n", "4520854288forward->4520854624forward\n", "\n", "\n", "\n", 
"\n", "4520853856forward\n", "\n", "x1= 1.50\n", "\n", "\n", "4520853856forward->4520853952forward\n", "\n", "\n", "\n", "\n", "4520854384forward->4520854480forward\n", "\n", "\n", "\n", "\n", "4520853952forward->4520854384forward\n", "\n", "\n", "\n", "\n", "4520854480forward->4520854672forward\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 计算图膨胀\n", "model = Linear()\n", "x1 = Scalar(1.5, label='x1', requires_grad=False)\n", "y1 = Scalar(1.0, label='y1', requires_grad=False)\n", "x2 = Scalar(2.0, label='x2', requires_grad=False)\n", "y2 = Scalar(4.0, label='y2', requires_grad=False)\n", "\n", "loss = mse([model.error(x1, y1), model.error(x2, y2)])\n", "draw_graph(loss)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "4520854528backward\n", "\n", "grad=-4.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4520854096backward\n", "\n", "grad=-9.50\n", "\n", "value= 0.00\n", "\n", "a\n", "\n", "\n", "4520854528backward->4520854096backward\n", "\n", "\n", "-8.00\n", "\n", "\n", "4520854192backward\n", "\n", "x2= 2.00\n", "\n", "\n", "4520854528backward->4520854192backward\n", "\n", "\n", "\n", "4520854576backward\n", "\n", "grad=-4.00\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4520854576backward->4520854528backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4520854144backward\n", "\n", "grad=-5.00\n", "\n", "value= 0.00\n", "\n", "b\n", "\n", "\n", "4520854576backward->4520854144backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4520854624backward\n", "\n", "grad= 4.00\n", "\n", "value= 4.00\n", "\n", "-\n", "\n", "\n", "4520854624backward->4520854576backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4520854288backward\n", "\n", "y2= 4.00\n", "\n", "\n", "4520854624backward->4520854288backward\n", "\n", "\n", "\n", "4520854672backward\n", 
"\n", "grad= 1.00\n", "\n", "value= 8.50\n", "\n", "mse\n", "\n", "\n", "4520854672backward->4520854624backward\n", "\n", "\n", " 4.00\n", "\n", "\n", "4520854480backward\n", "\n", "grad= 1.00\n", "\n", "value= 1.00\n", "\n", "-\n", "\n", "\n", "4520854672backward->4520854480backward\n", "\n", "\n", " 1.00\n", "\n", "\n", "4520854240backward\n", "\n", "y1= 1.00\n", "\n", "\n", "4520853856backward\n", "\n", "x1= 1.50\n", "\n", "\n", "4520854384backward\n", "\n", "grad=-1.00\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4520854384backward->4520854144backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4520853952backward\n", "\n", "grad=-1.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4520854384backward->4520853952backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4520853952backward->4520854096backward\n", "\n", "\n", "-1.50\n", "\n", "\n", "4520853952backward->4520853856backward\n", "\n", "\n", "\n", "4520854480backward->4520854240backward\n", "\n", "\n", "\n", "4520854480backward->4520854384backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 第一次触发方向传播\n", "loss.backward()\n", "draw_graph(loss, 'backward')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "4520854528backward\n", "\n", "grad=-8.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4520854096backward\n", "\n", "grad=-19.00\n", "\n", "value= 0.00\n", "\n", "a\n", "\n", "\n", "4520854528backward->4520854096backward\n", "\n", "\n", "-8.00\n", "\n", "\n", "4520854192backward\n", "\n", "x2= 2.00\n", "\n", "\n", "4520854528backward->4520854192backward\n", "\n", "\n", "\n", "4520854576backward\n", "\n", "grad=-8.00\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4520854576backward->4520854528backward\n", "\n", "\n", "-4.00\n", "\n", "\n", 
"4520854144backward\n", "\n", "grad=-10.00\n", "\n", "value= 0.00\n", "\n", "b\n", "\n", "\n", "4520854576backward->4520854144backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4520854624backward\n", "\n", "grad= 8.00\n", "\n", "value= 4.00\n", "\n", "-\n", "\n", "\n", "4520854624backward->4520854576backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4520854288backward\n", "\n", "y2= 4.00\n", "\n", "\n", "4520854624backward->4520854288backward\n", "\n", "\n", "\n", "4520854672backward\n", "\n", "grad= 2.00\n", "\n", "value= 8.50\n", "\n", "mse\n", "\n", "\n", "4520854672backward->4520854624backward\n", "\n", "\n", " 4.00\n", "\n", "\n", "4520854480backward\n", "\n", "grad= 2.00\n", "\n", "value= 1.00\n", "\n", "-\n", "\n", "\n", "4520854672backward->4520854480backward\n", "\n", "\n", " 1.00\n", "\n", "\n", "4520854240backward\n", "\n", "y1= 1.00\n", "\n", "\n", "4520853856backward\n", "\n", "x1= 1.50\n", "\n", "\n", "4520854384backward\n", "\n", "grad=-2.00\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4520854384backward->4520854144backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4520853952backward\n", "\n", "grad=-2.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4520854384backward->4520853952backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4520853952backward->4520854096backward\n", "\n", "\n", "-1.50\n", "\n", "\n", "4520853952backward->4520853856backward\n", "\n", "\n", "\n", "4520854480backward->4520854240backward\n", "\n", "\n", "\n", "4520854480backward->4520854384backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 第二次触发方向传播\n", "loss.backward()\n", "draw_graph(loss, 'backward')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# 固定随机种子,使得运行结果可以稳定复现\n", "torch.manual_seed(1024)\n", "# 产生训练用的数据\n", "x_origin = torch.linspace(100, 300, 200)\n", "# 
Standardize x (zero mean, unit std); otherwise gradient descent easily becomes unstable\n", "x = (x_origin - torch.mean(x_origin)) / torch.std(x_origin)\n", "epsilon = torch.randn(x.shape)\n", "y = 10 * x + 5 + epsilon" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 1, Result: y = 3.12 * x + -1.99\n", "Step 2, Result: y = 3.48 * x + -2.28\n", "Step 3, Result: y = 3.22 * x + -1.97\n", "Step 4, Result: y = 2.85 * x + -1.22\n", "Step 5, Result: y = 2.68 * x + -0.23\n", "Step 6, Result: y = 2.92 * x + 1.08\n", "Step 7, Result: y = 3.74 * x + 2.61\n", "Step 8, Result: y = 5.07 * x + 4.15\n", "Step 9, Result: y = 6.73 * x + 5.52\n", "Step 10, Result: y = 8.22 * x + 6.48\n", "Step 11, Result: y = 9.36 * x + 5.75\n", "Step 12, Result: y = 9.75 * x + 5.42\n", "Step 13, Result: y = 9.88 * x + 5.28\n", "Step 14, Result: y = 9.89 * x + 5.26\n", "Step 15, Result: y = 9.89 * x + 5.20\n", "Step 16, Result: y = 9.88 * x + 5.18\n", "Step 17, Result: y = 9.88 * x + 5.17\n", "Step 18, Result: y = 9.84 * x + 5.14\n", "Step 19, Result: y = 9.86 * x + 5.15\n", "Step 20, Result: y = 9.94 * x + 5.21\n" ] } ], "source": [ "# Create the model to be fitted\n", "model = Linear()\n", "# Mini-batch size, learning rate and number of update steps\n", "batch_size = 20\n", "learning_rate = 0.1\n", "n_steps = 20\n", "n_samples = len(x)\n", "\n", "for step in range(n_steps):\n", "    # Select the current mini-batch; the start offset wraps around at the end of the data\n", "    start = (step * batch_size) % n_samples\n", "    x_batch = x[start: start + batch_size]\n", "    y_batch = y[start: start + batch_size]\n", "    # Loss over the current mini-batch\n", "    loss = mse([model.error(_x, _y) for _x, _y in zip(x_batch, y_batch)])\n", "    # Back-propagate to get the gradients of the loss\n", "    loss.backward()\n", "    # Gradient-descent update of the parameter estimates\n", "    model.a -= learning_rate * model.a.grad\n", "    model.b -= learning_rate * model.b.grad\n", "    # Zero the used gradients: repeated backward() calls add to .grad, as the two backward cells above show\n", "    model.a.grad = 0.0\n", "    model.b.grad = 0.0\n", "    print(f'Step {step + 1}, Result: {model.string()}')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", 
"name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }