{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from utils import Scalar, draw_graph\n", "from linear_model import Linear, mse" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "4576520272backward\n", "\n", "grad=-5.00\n", "\n", "value= 0.00\n", "\n", "b\n", "\n", "\n", "4576520320backward\n", "\n", "grad=-1.00\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4576520320backward->4576520272backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4576520080backward\n", "\n", "grad=-1.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4576520320backward->4576520080backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4576519840backward\n", "\n", "y1= 1.00\n", "\n", "\n", "4576519888backward\n", "\n", "y2= 4.00\n", "\n", "\n", "4576520416backward\n", "\n", "grad= 1.00\n", "\n", "value= 8.50\n", "\n", "mse\n", "\n", "\n", "4576520464backward\n", "\n", "grad= 4.00\n", "\n", "value= 4.00\n", "\n", "-\n", "\n", "\n", "4576520416backward->4576520464backward\n", "\n", "\n", " 4.00\n", "\n", "\n", "4566148048backward\n", "\n", "grad= 1.00\n", "\n", "value= 1.00\n", "\n", "-\n", "\n", "\n", "4576520416backward->4566148048backward\n", "\n", "\n", " 1.00\n", "\n", "\n", "4576519936backward\n", "\n", "x1= 1.50\n", "\n", "\n", "4576520464backward->4576519888backward\n", "\n", "\n", "\n", "4576520608backward\n", "\n", "grad=-4.00\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4576520464backward->4576520608backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4576519984backward\n", "\n", "x2= 2.00\n", "\n", "\n", "4576520080backward->4576519936backward\n", "\n", "\n", "\n", "4576520128backward\n", "\n", "grad=-9.50\n", "\n", "value= 0.00\n", "\n", "a\n", "\n", "\n", "4576520080backward->4576520128backward\n", "\n", "\n", "-1.50\n", "\n", "\n", "4576520608backward->4576520272backward\n", "\n", "\n", "-4.00\n", "\n", 
"\n", "4576520176backward\n", "\n", "grad=-4.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4576520608backward->4576520176backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4566148048backward->4576520320backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4566148048backward->4576519840backward\n", "\n", "\n", "\n", "4576520176backward->4576519984backward\n", "\n", "\n", "\n", "4576520176backward->4576520128backward\n", "\n", "\n", "-8.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 定义训练数据\n", "x1 = Scalar(1.5, label='x1', requires_grad=False)\n", "y1 = Scalar(1.0, label='y1', requires_grad=False)\n", "x2 = Scalar(2.0, label='x2', requires_grad=False)\n", "y2 = Scalar(4.0, label='y2', requires_grad=False)\n", "# 定义正常的计算图\n", "model = Linear()\n", "k = model.forward(x1)\n", "l = y1 - k\n", "loss = mse([l, model.error(x2, y2)])\n", "# 反向传播\n", "loss.backward()\n", "draw_graph(loss, 'backward')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "4576521232backward\n", "\n", "x1= 1.50\n", "\n", "\n", "4576520224backward\n", "\n", "grad=-1.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4576633136backward\n", "\n", "grad= 0.00\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4576520224backward->4576633136backward\n", "\n", "\n", " 0.00\n", "\n", "\n", "4576520032backward\n", "\n", "input= 0.00\n", "\n", "\n", "4576520224backward->4576520032backward\n", "\n", "\n", "\n", "4576632896backward\n", "\n", "grad= 0.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4576632896backward->4576521232backward\n", "\n", "\n", "\n", "4576632944backward\n", "\n", "grad=-8.00\n", "\n", "value= 0.00\n", "\n", "a\n", "\n", "\n", "4576632896backward->4576632944backward\n", "\n", "\n", " 0.00\n", "\n", "\n", "4576633472backward\n", "\n", "grad=-4.00\n", "\n", "value= 
0.00\n", "\n", "+\n", "\n", "\n", "4576633088backward\n", "\n", "grad=-4.00\n", "\n", "value= 0.00\n", "\n", "b\n", "\n", "\n", "4576633472backward->4576633088backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4576633280backward\n", "\n", "grad=-4.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4576633472backward->4576633280backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4576632992backward\n", "\n", "grad= 1.00\n", "\n", "value= 1.00\n", "\n", "-\n", "\n", "\n", "4576632992backward->4576520224backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4576522144backward\n", "\n", "y1= 1.00\n", "\n", "\n", "4576632992backward->4576522144backward\n", "\n", "\n", "\n", "4576518352backward\n", "\n", "x2= 2.00\n", "\n", "\n", "4576633136backward->4576632896backward\n", "\n", "\n", " 0.00\n", "\n", "\n", "4576633136backward->4576633088backward\n", "\n", "\n", " 0.00\n", "\n", "\n", "4576522096backward\n", "\n", "y2= 4.00\n", "\n", "\n", "4576633232backward\n", "\n", "grad= 4.00\n", "\n", "value= 4.00\n", "\n", "-\n", "\n", "\n", "4576633232backward->4576633472backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4576633232backward->4576522096backward\n", "\n", "\n", "\n", "4576633280backward->4576632944backward\n", "\n", "\n", "-8.00\n", "\n", "\n", "4576633280backward->4576518352backward\n", "\n", "\n", "\n", "4576633328backward\n", "\n", "grad= 1.00\n", "\n", "value= 8.50\n", "\n", "mse\n", "\n", "\n", "4576633328backward->4576632992backward\n", "\n", "\n", " 1.00\n", "\n", "\n", "4576633328backward->4576633232backward\n", "\n", "\n", " 4.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Training data: the same two (x, y) samples as in the previous cell.\n", "x1 = Scalar(1.5, label='x1', requires_grad=False)\n", "y1 = Scalar(1.0, label='y1', requires_grad=False)\n", "x2 = Scalar(2.0, label='x2', requires_grad=False)\n", "y2 = Scalar(4.0, label='y2', requires_grad=False)\n", "model = Linear()\n", "pred1 = model.forward(x1)\n", "# Deactivate the first prediction by multiplying it by 0; in the recorded output\n", "# the edges from this branch to a and b carry a gradient of 0.00.\n", "pred1_off = pred1 * 0\n", 
"residual1 = y1 - pred1_off\n", "loss = mse([residual1, model.error(x2, y2)])\n", "# Backpropagate from the loss, then draw the resulting graph.\n", "loss.backward()\n", "draw_graph(loss, 'backward')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "4566380064backward\n", "\n", "input= 0.00\n", "\n", "\n", "4566379584backward\n", "\n", "grad= 0.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4566379584backward->4566380064backward\n", "\n", "\n", "\n", "4566379104backward\n", "\n", "input= 1.50\n", "\n", "\n", "4566379584backward->4566379104backward\n", "\n", "\n", "\n", "4566377040backward\n", "\n", "grad=-8.00\n", "\n", "value= 0.00\n", "\n", "a\n", "\n", "\n", "4576518208backward\n", "\n", "grad= 4.00\n", "\n", "value= 4.00\n", "\n", "-\n", "\n", "\n", "4576518448backward\n", "\n", "grad=-4.00\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4576518208backward->4576518448backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4566380352backward\n", "\n", "y2= 4.00\n", "\n", "\n", "4576518208backward->4566380352backward\n", "\n", "\n", "\n", "4566379152backward\n", "\n", "grad=-1.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4566379152backward->4566379584backward\n", "\n", "\n", "\n", "4566379152backward->4566377040backward\n", "\n", "\n", " 0.00\n", "\n", "\n", "4576518304backward\n", "\n", "grad=-4.00\n", "\n", "value= 0.00\n", "\n", "*\n", "\n", "\n", "4576518304backward->4566377040backward\n", "\n", "\n", "-8.00\n", "\n", "\n", "4566379344backward\n", "\n", "x2= 2.00\n", "\n", "\n", "4576518304backward->4566379344backward\n", "\n", "\n", "\n", "4566380256backward\n", "\n", "grad=-1.00\n", "\n", "value= 0.00\n", "\n", "+\n", "\n", "\n", "4566380256backward->4566379152backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4566380016backward\n", "\n", "grad=-5.00\n", "\n", "value= 0.00\n", "\n", "b\n", "\n", "\n", "4566380256backward->4566380016backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4576518400backward\n", "\n", 
"grad= 1.00\n", "\n", "value= 1.00\n", "\n", "-\n", "\n", "\n", "4576518400backward->4566380256backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "4566380304backward\n", "\n", "y1= 1.00\n", "\n", "\n", "4576518400backward->4566380304backward\n", "\n", "\n", "\n", "4576518928backward\n", "\n", "grad= 1.00\n", "\n", "value= 8.50\n", "\n", "mse\n", "\n", "\n", "4576518928backward->4576518208backward\n", "\n", "\n", " 4.00\n", "\n", "\n", "4576518928backward->4576518400backward\n", "\n", "\n", " 1.00\n", "\n", "\n", "4576518448backward->4576518304backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "4576518448backward->4566380016backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 为了减少计算图的歧义,将x1的标签省略掉\n", "x1 = Scalar(1.5, requires_grad=False)\n", "y1 = Scalar(1.0, label='y1', requires_grad=False)\n", "x2 = Scalar(2.0, label='x2', requires_grad=False)\n", "y2 = Scalar(4.0, label='y2', requires_grad=False)\n", "# 将变量x1失活\n", "x1_out = x1 * 0\n", "model = Linear()\n", "loss = mse([model.error(x1_out, y1), model.error(x2, y2)])\n", "# 反向传播\n", "loss.backward()\n", "draw_graph(loss, 'backward')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }