# %% Imports
import os
import sys

import torch

# The autograd helpers (Scalar, Linear, mse, draw_graph) live in a sibling
# chapter directory, so that directory has to be put on the import path first.
sys.path.append(os.path.abspath(os.path.join('../ch07_autograd')))
from utils import Scalar, draw_graph
from linear_model import Linear, mse

# %% Reproducibility
torch.manual_seed(1024)

# %% Synthetic data: y = 10 * x + 5 + noise, with x standardised
x = torch.linspace(100, 300, 200)
x = (x - torch.mean(x)) / torch.std(x)
epsilon = torch.randn(x.shape)
y = 10 * x + 5 + epsilon


# %% Training helper shared by the plain and the gradient-accumulation runs
def train(x, y, batch_size=20, learning_rate=0.1, n_updates=20,
          gradient_accu_iter=1):
    """Fit a fresh ``Linear`` model with mini-batch gradient descent.

    Each parameter update uses ``batch_size`` consecutive samples. When
    ``gradient_accu_iter > 1`` the batch is processed as that many equally
    sized micro-batches: every micro-batch loss is scaled by
    ``1 / gradient_accu_iter`` and back-propagated, the gradients add up on
    ``model.a`` / ``model.b``, and the parameters are only updated (and the
    gradients reset) after the last micro-batch of the batch.

    Args:
        x: 1-D sequence of inputs (sliced and iterated element by element).
        y: 1-D sequence of targets, aligned with ``x``.
        batch_size: Number of samples contributing to one parameter update.
        learning_rate: Step size of the gradient-descent update.
        n_updates: Number of parameter updates to perform.
        gradient_accu_iter: Number of micro-batches per update; ``1`` means
            plain mini-batch gradient descent without loss rescaling.

    Returns:
        The trained ``Linear`` model. ``model.string()`` is printed after
        every parameter update.

    Raises:
        ValueError: If ``gradient_accu_iter`` is smaller than 1 or does not
            divide ``batch_size`` evenly (the micro-batches would otherwise
            silently cover fewer samples than ``batch_size``).
    """
    if gradient_accu_iter < 1 or batch_size % gradient_accu_iter != 0:
        raise ValueError(
            'gradient_accu_iter must be a positive divisor of batch_size'
        )
    # Number of samples in one micro-batch.
    micro_size = batch_size // gradient_accu_iter

    model = Linear()
    for t in range(n_updates * gradient_accu_iter):
        # Walk through the data in order, wrapping around at the end.
        ix = (t * micro_size) % len(x)
        xx = x[ix: ix + micro_size]
        yy = y[ix: ix + micro_size]
        loss = mse([model.error(_x, _y) for _x, _y in zip(xx, yy)])
        if gradient_accu_iter > 1:
            # Re-weight the micro-batch loss so that the accumulated gradient
            # matches the gradient of the loss over the whole batch.
            loss *= 1 / gradient_accu_iter
        loss.backward()
        if (t + 1) % gradient_accu_iter == 0:
            model.a -= learning_rate * model.a.grad
            model.b -= learning_rate * model.b.grad
            # Gradients accumulate across backward() calls, so reset them
            # once they have been consumed by the update.
            model.a.grad = 0.0
            model.b.grad = 0.0
            print(model.string())
    return model


# %% Plain mini-batch gradient descent
batch_size = 20
learning_rate = 0.1

model = train(x, y, batch_size=batch_size, learning_rate=learning_rate)

# %% Computation-graph growth: a single graph built over two samples
model = Linear()
# Two data points, created as constants (requires_grad=False).
x1 = Scalar(1.5, label='x1', requires_grad=False)
y1 = Scalar(1.0, label='y1', requires_grad=False)
x2 = Scalar(2.0, label='x2', requires_grad=False)
y2 = Scalar(4.0, label='y2', requires_grad=False)
loss = mse([model.error(x1, y1), model.error(x2, y2)])
loss.backward()
draw_graph(loss, 'backward')

# %% First backward pass: first sample only, loss weighted by 0.5
model = Linear()
loss = 0.5 * mse([model.error(x1, y1)])
loss.backward()
draw_graph(loss, 'backward')

# %% Second backward pass (gradient accumulation): second sample, same model
# `model` is deliberately NOT re-created here, so the gradients of this pass
# are added to the ones left by the previous cell.
loss = 0.5 * mse([model.error(x2, y2)])
loss.backward()
draw_graph(loss, 'backward')

# %% Mini-batch gradient descent with gradient accumulation
# Number of micro-batches whose gradients are accumulated per update.
gradient_accu_iter = 4

model = train(
    x,
    y,
    batch_size=batch_size,
    learning_rate=learning_rate,
    gradient_accu_iter=gradient_accu_iter,
)