{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from utils import Scalar, draw_graph\n", "from linear_model import Linear, mse" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "140621852371504forward\n", "\n", "grad=None\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140621852371600forward\n", "\n", "grad=None\n", "\n", "value=1.00\n", "\n", "-\n", "\n", "\n", "140621852371504forward->140621852371600forward\n", "\n", "\n", "\n", "\n", "140621852371072forward\n", "\n", "x1=1.50\n", "\n", "\n", "140621852371120forward\n", "\n", "grad=None\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140621852371072forward->140621852371120forward\n", "\n", "\n", "\n", "\n", "140621852371792forward\n", "\n", "grad=None\n", "\n", "value=8.50\n", "\n", "mse\n", "\n", "\n", "140621852371600forward->140621852371792forward\n", "\n", "\n", "\n", "\n", "140621852371120forward->140621852371504forward\n", "\n", "\n", "\n", "\n", "140621852371648forward\n", "\n", "grad=None\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140621852371696forward\n", "\n", "grad=None\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140621852371648forward->140621852371696forward\n", "\n", "\n", "\n", "\n", "140621852371168forward\n", "\n", "grad=None\n", "\n", "value=0.00\n", "\n", "a\n", "\n", "\n", "140621852371168forward->140621852371120forward\n", "\n", "\n", "\n", "\n", "140621852371168forward->140621852371648forward\n", "\n", "\n", "\n", "\n", "140621852371744forward\n", "\n", "grad=None\n", "\n", "value=4.00\n", "\n", "-\n", "\n", "\n", "140621852371696forward->140621852371744forward\n", "\n", "\n", "\n", "\n", "140621852371744forward->140621852371792forward\n", "\n", "\n", "\n", "\n", "140621852371264forward\n", "\n", "grad=None\n", "\n", "value=0.00\n", "\n", "b\n", "\n", "\n", 
"140621852371264forward->140621852371504forward\n", "\n", "\n", "\n", "\n", "140621852371264forward->140621852371696forward\n", "\n", "\n", "\n", "\n", "140621852371312forward\n", "\n", "x2=2.00\n", "\n", "\n", "140621852371312forward->140621852371648forward\n", "\n", "\n", "\n", "\n", "140621852371360forward\n", "\n", "y1=1.00\n", "\n", "\n", "140621852371360forward->140621852371600forward\n", "\n", "\n", "\n", "\n", "140621852371408forward\n", "\n", "y2=4.00\n", "\n", "\n", "140621852371408forward->140621852371744forward\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 计算图膨胀\n", "model = Linear()\n", "# 定义训练数据\n", "x1 = Scalar(1.5, label='x1', requires_grad=False)\n", "y1 = Scalar(1.0, label='y1', requires_grad=False)\n", "x2 = Scalar(2.0, label='x2', requires_grad=False)\n", "y2 = Scalar(4.0, label='y2', requires_grad=False)\n", "\n", "loss = mse([model.error(x1, y1), model.error(x2, y2)])\n", "draw_graph(loss)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "140621852371504backward\n", "\n", "grad=-1.00\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140621852371120backward\n", "\n", "grad=-1.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140621852371504backward->140621852371120backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140621852371264backward\n", "\n", "grad=-5.00\n", "\n", "value=0.00\n", "\n", "b\n", "\n", "\n", "140621852371504backward->140621852371264backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140621852371072backward\n", "\n", "x1=1.50\n", "\n", "\n", "140621852371600backward\n", "\n", "grad=1.00\n", "\n", "value=1.00\n", "\n", "-\n", "\n", "\n", "140621852371600backward->140621852371504backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140621852371360backward\n", "\n", "y1=1.00\n", "\n", "\n", 
"140621852371600backward->140621852371360backward\n", "\n", "\n", "\n", "140621852371120backward->140621852371072backward\n", "\n", "\n", "\n", "140621852371168backward\n", "\n", "grad=-9.50\n", "\n", "value=0.00\n", "\n", "a\n", "\n", "\n", "140621852371120backward->140621852371168backward\n", "\n", "\n", "-1.50\n", "\n", "\n", "140621852371648backward\n", "\n", "grad=-4.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140621852371648backward->140621852371168backward\n", "\n", "\n", "-8.00\n", "\n", "\n", "140621852371312backward\n", "\n", "x2=2.00\n", "\n", "\n", "140621852371648backward->140621852371312backward\n", "\n", "\n", "\n", "140621852371696backward\n", "\n", "grad=-4.00\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140621852371696backward->140621852371648backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140621852371696backward->140621852371264backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140621852371744backward\n", "\n", "grad=4.00\n", "\n", "value=4.00\n", "\n", "-\n", "\n", "\n", "140621852371744backward->140621852371696backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140621852371408backward\n", "\n", "y2=4.00\n", "\n", "\n", "140621852371744backward->140621852371408backward\n", "\n", "\n", "\n", "140621852371792backward\n", "\n", "grad=1.00\n", "\n", "value=8.50\n", "\n", "mse\n", "\n", "\n", "140621852371792backward->140621852371600backward\n", "\n", "\n", "1.00\n", "\n", "\n", "140621852371792backward->140621852371744backward\n", "\n", "\n", "4.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# First backward pass (backpropagation)\n", "loss.backward()\n", "draw_graph(loss, 'backward')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "140621852371504backward\n", "\n", "grad=-2.00\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140621852371120backward\n",
"\n", "grad=-2.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140621852371504backward->140621852371120backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140621852371264backward\n", "\n", "grad=-10.00\n", "\n", "value=0.00\n", "\n", "b\n", "\n", "\n", "140621852371504backward->140621852371264backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140621852371072backward\n", "\n", "x1=1.50\n", "\n", "\n", "140621852371600backward\n", "\n", "grad=2.00\n", "\n", "value=1.00\n", "\n", "-\n", "\n", "\n", "140621852371600backward->140621852371504backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140621852371360backward\n", "\n", "y1=1.00\n", "\n", "\n", "140621852371600backward->140621852371360backward\n", "\n", "\n", "\n", "140621852371120backward->140621852371072backward\n", "\n", "\n", "\n", "140621852371168backward\n", "\n", "grad=-19.00\n", "\n", "value=0.00\n", "\n", "a\n", "\n", "\n", "140621852371120backward->140621852371168backward\n", "\n", "\n", "-1.50\n", "\n", "\n", "140621852371648backward\n", "\n", "grad=-8.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140621852371648backward->140621852371168backward\n", "\n", "\n", "-8.00\n", "\n", "\n", "140621852371312backward\n", "\n", "x2=2.00\n", "\n", "\n", "140621852371648backward->140621852371312backward\n", "\n", "\n", "\n", "140621852371696backward\n", "\n", "grad=-8.00\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140621852371696backward->140621852371648backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140621852371696backward->140621852371264backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140621852371744backward\n", "\n", "grad=8.00\n", "\n", "value=4.00\n", "\n", "-\n", "\n", "\n", "140621852371744backward->140621852371696backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140621852371408backward\n", "\n", "y2=4.00\n", "\n", "\n", "140621852371744backward->140621852371408backward\n", "\n", "\n", "\n", "140621852371792backward\n", "\n", "grad=2.00\n", "\n", "value=8.50\n", "\n", "mse\n", "\n", "\n", 
"140621852371792backward->140621852371600backward\n", "\n", "\n", "1.00\n", "\n", "\n", "140621852371792backward->140621852371744backward\n", "\n", "\n", "4.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 第二次触发方向传播\n", "loss.backward()\n", "draw_graph(loss, 'backward')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# 固定随机种子,使得运行结果可以稳定复现\n", "torch.manual_seed(1024)\n", "# 产生训练用的数据\n", "x_origin = torch.linspace(100, 300, 200)\n", "# 将变量X归一化,否则梯度下降法很容易不稳定\n", "x = (x_origin - torch.mean(x_origin)) / torch.std(x_origin)\n", "epsilon = torch.randn(x.shape)\n", "y = 10 * x + 5 + epsilon" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 1, Result: y = 3.12 * x + -1.99\n", "Step 2, Result: y = 3.48 * x + -2.28\n", "Step 3, Result: y = 3.22 * x + -1.97\n", "Step 4, Result: y = 2.85 * x + -1.22\n", "Step 5, Result: y = 2.68 * x + -0.23\n", "Step 6, Result: y = 2.92 * x + 1.08\n", "Step 7, Result: y = 3.74 * x + 2.61\n", "Step 8, Result: y = 5.07 * x + 4.15\n", "Step 9, Result: y = 6.73 * x + 5.52\n", "Step 10, Result: y = 8.22 * x + 6.48\n", "Step 11, Result: y = 9.36 * x + 5.75\n", "Step 12, Result: y = 9.75 * x + 5.42\n", "Step 13, Result: y = 9.88 * x + 5.28\n", "Step 14, Result: y = 9.89 * x + 5.26\n", "Step 15, Result: y = 9.89 * x + 5.20\n", "Step 16, Result: y = 9.88 * x + 5.18\n", "Step 17, Result: y = 9.88 * x + 5.17\n", "Step 18, Result: y = 9.84 * x + 5.14\n", "Step 19, Result: y = 9.86 * x + 5.15\n", "Step 20, Result: y = 9.94 * x + 5.21\n" ] } ], "source": [ "# 生成模型\n", "model = Linear()\n", "# 定义每批次用到的数据量\n", "batch_size = 20\n", "learning_rate = 0.1\n", "\n", "for t in range(20):\n", " # 选取当前批次的数据,用于训练模型\n", " ix = (t * batch_size) % len(x)\n", " xx = x[ix: ix + batch_size]\n", " yy = y[ix: ix + 
batch_size]\n", " # 计算当前批次数据的损失\n", " loss = mse([model.error(_x, _y) for _x, _y in zip(xx, yy)])\n", " # 计算损失函数的梯度\n", " loss.backward()\n", " # 迭代更新模型参数的估计值\n", " model.a -= learning_rate * model.a.grad\n", " model.b -= learning_rate * model.b.grad\n", " # 将使用完的梯度清零\n", " model.a.grad = 0.0\n", " model.b.grad = 0.0\n", " print(f'Step {t + 1}, Result: {model.string()}')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }