{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from utils import Scalar, draw_graph\n", "from linear_model import Linear, mse" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "140367041063952backward\n", "\n", "grad=-1.00\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140367041063712backward\n", "\n", "grad=-1.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140367041063952backward->140367041063712backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140367041063904backward\n", "\n", "grad=-5.00\n", "\n", "value=0.00\n", "\n", "b\n", "\n", "\n", "140367041063952backward->140367041063904backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140367041063472backward\n", "\n", "y1=1.00\n", "\n", "\n", "140367041064000backward\n", "\n", "grad=1.00\n", "\n", "value=1.00\n", "\n", "-\n", "\n", "\n", "140367041064000backward->140367041063952backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140367041064000backward->140367041063472backward\n", "\n", "\n", "\n", "140367041063520backward\n", "\n", "y2=4.00\n", "\n", "\n", "140367041063568backward\n", "\n", "x1=1.50\n", "\n", "\n", "140367041064096backward\n", "\n", "grad=4.00\n", "\n", "value=4.00\n", "\n", "-\n", "\n", "\n", "140367041064096backward->140367041063520backward\n", "\n", "\n", "\n", "140367041064288backward\n", "\n", "grad=-4.00\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140367041064096backward->140367041064288backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140367041063616backward\n", "\n", "x2=2.00\n", "\n", "\n", "140367041064144backward\n", "\n", "grad=1.00\n", "\n", "value=8.50\n", "\n", "mse\n", "\n", "\n", "140367041064144backward->140367041064000backward\n", "\n", "\n", "1.00\n", "\n", "\n", "140367041064144backward->140367041064096backward\n", "\n", "\n", "4.00\n", "\n", "\n", "140367041063712backward->140367041063568backward\n", "\n", "\n", "\n", "140367041063760backward\n", "\n", "grad=-9.50\n", "\n", "value=0.00\n", "\n", "a\n", "\n", "\n", "140367041063712backward->140367041063760backward\n", "\n", "\n", "-1.50\n", "\n", "\n", "140367041063808backward\n", "\n", "grad=-4.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140367041064288backward->140367041063808backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140367041064288backward->140367041063904backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140367041063808backward->140367041063616backward\n", "\n", "\n", "\n", "140367041063808backward->140367041063760backward\n", "\n", "\n", "-8.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 定义训练数据\n", "x1 = Scalar(1.5, label='x1', requires_grad=False)\n", "y1 = Scalar(1.0, label='y1', requires_grad=False)\n", "x2 = Scalar(2.0, label='x2', requires_grad=False)\n", "y2 = Scalar(4.0, label='y2', requires_grad=False)\n", "# 定义正常的计算图\n", "model = Linear()\n", "k = model.forward(x1)\n", "l = y1 - k\n", "loss = mse([l, model.error(x2, y2)])\n", "# 反向传播\n", "loss.backward()\n", "draw_graph(loss, 'backward')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "140367041130064backward\n", "\n", "grad=0.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140367041130112backward\n", "\n", "grad=-8.00\n", "\n", "value=0.00\n", "\n", "a\n", "\n", "\n", "140367041130064backward->140367041130112backward\n", "\n", "\n", "0.00\n", "\n", "\n", "140367041129920backward\n", "\n", "x1=1.50\n", "\n", "\n", "140367041130064backward->140367041129920backward\n", "\n", "\n", "\n", "140367041061312backward\n", "\n", "grad=-4.00\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140367041130256backward\n", "\n", "grad=-4.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140367041061312backward->140367041130256backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140367041130352backward\n", "\n", "grad=-4.00\n", "\n", "value=0.00\n", "\n", "b\n", "\n", "\n", "140367041061312backward->140367041130352backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140367041130208backward\n", "\n", "grad=1.00\n", "\n", "value=1.00\n", "\n", "-\n", "\n", "\n", "140367041129824backward\n", "\n", "y1=1.00\n", "\n", "\n", "140367041130208backward->140367041129824backward\n", "\n", "\n", "\n", "140367041129344backward\n", "\n", "grad=-1.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140367041130208backward->140367041129344backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140367041063664backward\n", "\n", "grad=4.00\n", "\n", "value=4.00\n", "\n", "-\n", "\n", "\n", "140367041063664backward->140367041061312backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140367041129872backward\n", "\n", "y2=4.00\n", "\n", "\n", "140367041063664backward->140367041129872backward\n", "\n", "\n", "\n", "140367041130256backward->140367041130112backward\n", "\n", "\n", "-8.00\n", "\n", "\n", "140367041129968backward\n", "\n", "x2=2.00\n", "\n", "\n", "140367041130256backward->140367041129968backward\n", "\n", "\n", "\n", "140367041130400backward\n", "\n", "grad=0.00\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140367041129344backward->140367041130400backward\n", "\n", "\n", "0.00\n", "\n", "\n", "140367041130448backward\n", "\n", "input=0.00\n", "\n", "\n", "140367041129344backward->140367041130448backward\n", "\n", "\n", "\n", "140367041130400backward->140367041130064backward\n", "\n", "\n", "0.00\n", "\n", "\n", "140367041130400backward->140367041130352backward\n", "\n", "\n", "0.00\n", "\n", "\n", "140367041063856backward\n", "\n", "grad=1.00\n", "\n", "value=8.50\n", "\n", "mse\n", "\n", "\n", "140367041063856backward->140367041130208backward\n", "\n", "\n", "1.00\n", "\n", "\n", "140367041063856backward->140367041063664backward\n", "\n", "\n", "4.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1 = Scalar(1.5, label='x1', requires_grad=False)\n", "y1 = Scalar(1.0, label='y1', requires_grad=False)\n", "x2 = Scalar(2.0, label='x2', requires_grad=False)\n", "y2 = Scalar(4.0, label='y2', requires_grad=False)\n", "model = Linear()\n", "k = model.forward(x1)\n", "# 将k失活\n", "k_out = k * 0\n", "l = y1 - k_out\n", "loss = mse([l, model.error(x2, y2)])\n", "# 反向传播\n", "loss.backward()\n", "draw_graph(loss, 'backward')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "140367041130016backward\n", "\n", "grad=4.00\n", "\n", "value=4.00\n", "\n", "-\n", "\n", "\n", "140367038669008backward\n", "\n", "grad=-4.00\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140367041130016backward->140367038669008backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140367038537632backward\n", "\n", "y2=4.00\n", "\n", "\n", "140367041130016backward->140367038537632backward\n", "\n", "\n", "\n", "140367038669872backward\n", "\n", "grad=0.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140367038671456backward\n", "\n", "input=0.00\n", "\n", "\n", "140367038669872backward->140367038671456backward\n", "\n", "\n", "\n", "140367038536480backward\n", "\n", "input=1.50\n", "\n", "\n", "140367038669872backward->140367038536480backward\n", "\n", "\n", "\n", "140367038668864backward\n", "\n", "grad=-4.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140367038669104backward\n", "\n", "grad=-8.00\n", "\n", "value=0.00\n", "\n", "a\n", "\n", "\n", "140367038668864backward->140367038669104backward\n", "\n", "\n", "-8.00\n", "\n", "\n", "140367038669776backward\n", "\n", "x2=2.00\n", "\n", "\n", "140367038668864backward->140367038669776backward\n", "\n", "\n", "\n", "140367038669440backward\n", "\n", "grad=-5.00\n", "\n", "value=0.00\n", "\n", "b\n", "\n", "\n", "140367038670016backward\n", "\n", "grad=1.00\n", "\n", "value=1.00\n", "\n", "-\n", "\n", "\n", "140367038535952backward\n", "\n", "y1=1.00\n", "\n", "\n", "140367038670016backward->140367038535952backward\n", "\n", "\n", "\n", "140367038669296backward\n", "\n", "grad=-1.00\n", "\n", "value=0.00\n", "\n", "+\n", "\n", "\n", "140367038670016backward->140367038669296backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140367038669008backward->140367038668864backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140367038669008backward->140367038669440backward\n", "\n", "\n", "-4.00\n", "\n", "\n", "140367041130304backward\n", "\n", "grad=1.00\n", "\n", "value=8.50\n", "\n", "mse\n", "\n", "\n", "140367041130304backward->140367041130016backward\n", "\n", "\n", "4.00\n", "\n", "\n", "140367041130304backward->140367038670016backward\n", "\n", "\n", "1.00\n", "\n", "\n", "140367038671792backward\n", "\n", "grad=-1.00\n", "\n", "value=0.00\n", "\n", "*\n", "\n", "\n", "140367038671792backward->140367038669872backward\n", "\n", "\n", "\n", "140367038671792backward->140367038669104backward\n", "\n", "\n", "0.00\n", "\n", "\n", "140367038669296backward->140367038669440backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "140367038669296backward->140367038671792backward\n", "\n", "\n", "-1.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 为了减少计算图的歧义,将x1的标签省略掉\n", "x1 = Scalar(1.5, requires_grad=False)\n", "y1 = Scalar(1.0, label='y1', requires_grad=False)\n", "x2 = Scalar(2.0, label='x2', requires_grad=False)\n", "y2 = Scalar(4.0, label='y2', requires_grad=False)\n", "# 将变量x1失活\n", "x1_out = x1 * 0\n", "model = Linear()\n", "loss = mse([model.error(x1_out, y1), model.error(x2, y2)])\n", "# 反向传播\n", "loss.backward()\n", "draw_graph(loss, 'backward')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }