{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install the third-party dependency.\n", "# %pip (rather than !pip) installs into the environment of the running kernel.\n", "%pip install pygraphviz" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Scalar and draw_graph come from the project-local utils module.\n", "from utils import Scalar, draw_graph" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "140288582452992forward\n", "\n", "grad=None\n", "\n", "value=2.00\n", "\n", "b\n", "\n", "\n", "140288582453136forward\n", "\n", "grad=None\n", "\n", "value=3.00\n", "\n", "+\n", "\n", "\n", "140288582452992forward->140288582453136forward\n", "\n", "\n", "\n", "\n", "140288582453088forward\n", "\n", "grad=None\n", "\n", "value=1.00\n", "\n", "a\n", "\n", "\n", "140288582453088forward->140288582453136forward\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A simple computation graph: c = a + b\n", "a = Scalar(1.0, label='a')\n", "b = Scalar(2.0, label='b')\n", "c = a + b\n", "draw_graph(c)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "140288582453856forward\n", "\n", "grad=None\n", "\n", "value=2.00\n", "\n", "b\n", "\n", "\n", "140288582452944forward\n", "\n", "grad=None\n", "\n", "value=3.00\n", "\n", "+\n", "\n", "\n", "140288582453856forward->140288582452944forward\n", "\n", "\n", "\n", "\n", "140288582454384forward\n", "\n", "grad=None\n", "\n", "value=12.00\n", "\n", "*\n", "\n", "\n", "140288582454432forward\n", "\n", "grad=None\n", "\n", "value=1.00\n", "\n", "a\n", "\n", "\n", "140288582454432forward->140288582452944forward\n", "\n", "\n", "\n", "\n", "140288582453136forward\n", "\n", "grad=None\n", "\n", "value=4.00\n", "\n", "*\n", "\n", "\n", "140288582454432forward->140288582453136forward\n", "\n", "\n", "\n", "\n", 
"140288582452944forward->140288582454384forward\n", "\n", "\n", "\n", "\n", "140288582455104forward\n", "\n", "grad=None\n", "\n", "value=4.00\n", "\n", "c\n", "\n", "\n", "140288582455104forward->140288582453136forward\n", "\n", "\n", "\n", "\n", "140288582453136forward->140288582454384forward\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A slightly more complex computation graph: f = (a + b) * (a * c)\n", "a = Scalar(1.0, label='a')\n", "b = Scalar(2.0, label='b')\n", "c = Scalar(4.0, label='c')\n", "d = a + b\n", "e = a * c\n", "f = d * e\n", "# backward() is handed draw_graph; its result is iterated as a sequence of\n", "# pictures in the last cell (presumably one per backward step - confirm in utils).\n", "backward_process = f.backward(draw_graph)\n", "draw_graph(f)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n", "140288582453856backward\n", "\n", "grad=4.00\n", "\n", "value=2.00\n", "\n", "b\n", "\n", "\n", "140288582454384backward\n", "\n", "grad=1.00\n", "\n", "value=12.00\n", "\n", "*\n", "\n", "\n", "140288582452944backward\n", "\n", "grad=4.00\n", "\n", "value=3.00\n", "\n", "+\n", "\n", "\n", "140288582454384backward->140288582452944backward\n", "\n", "\n", "4.00\n", "\n", "\n", "140288582453136backward\n", "\n", "grad=3.00\n", "\n", "value=4.00\n", "\n", "*\n", "\n", "\n", "140288582454384backward->140288582453136backward\n", "\n", "\n", "3.00\n", "\n", "\n", "140288582454432backward\n", "\n", "grad=16.00\n", "\n", "value=1.00\n", "\n", "a\n", "\n", "\n", "140288582452944backward->140288582453856backward\n", "\n", "\n", "4.00\n", "\n", "\n", "140288582452944backward->140288582454432backward\n", "\n", "\n", "4.00\n", "\n", "\n", "140288582455104backward\n", "\n", "grad=3.00\n", "\n", "value=4.00\n", "\n", "c\n", "\n", "\n", "140288582453136backward->140288582454432backward\n", "\n", "\n", "12.00\n", "\n", "\n", "140288582453136backward->140288582455104backward\n", "\n", "\n", "3.00\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 5, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "# Draw the graph of f again with the 'backward' option; the saved output shows\n", "# gradient values on the nodes and edges.\n", "draw_graph(f, 'backward')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Show each step of the backward pass (may open pop-up windows)\n", "for index, pic in enumerate(backward_process):\n", " pic.view(str(index))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }