{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use %pip (not !pip) so the package is installed into the environment of the running kernel.\n",
    "%pip install -q pygraphviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import Scalar, draw_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n",
       "4416457600forward\n", "\n", "grad=None\n", "\n", "value= 2.00\n", "\n", "b\n", "\n", "\n",
       "4416458128forward\n", "\n", "grad=None\n", "\n", "value= 3.00\n", "\n", "+\n", "\n", "\n",
       "4416457600forward->4416458128forward\n", "\n", "\n", "\n", "\n",
       "4416457264forward\n", "\n", "grad=None\n", "\n", "value= 1.00\n", "\n", "a\n", "\n", "\n",
       "4416457264forward->4416458128forward\n", "\n", "\n", "\n", "\n", "\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A simple computation graph: c = a + b\n",
    "a = Scalar(1.0, label='a')\n",
    "b = Scalar(2.0, label='b')\n",
    "c = a + b\n",
    "draw_graph(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n",
       "4416480256forward\n", "\n", "grad=None\n", "\n", "value= 2.00\n", "\n", "b\n", "\n", "\n",
       "4416457264forward\n", "\n", "grad=None\n", "\n", "value= 3.00\n", "\n", "+\n", "\n", "\n",
       "4416480256forward->4416457264forward\n", "\n", "\n", "\n", "\n",
       "4416480304forward\n", "\n", "grad=None\n", "\n", "value= 12.00\n", "\n", "*\n", "\n", "\n",
       "4416457264forward->4416480304forward\n", "\n", "\n", "\n", "\n",
       "4416480352forward\n", "\n", "grad=None\n", "\n", "value= 1.00\n", "\n", "a\n", "\n", "\n",
       "4416480352forward->4416457264forward\n", "\n", "\n", "\n", "\n",
       "4416458128forward\n", "\n", "grad=None\n", "\n", "value= 4.00\n", "\n", "*\n", "\n", "\n",
       "4416480352forward->4416458128forward\n", "\n", "\n", "\n", "\n",
       "4416480400forward\n", "\n", "grad=None\n", "\n", "value= 4.00\n", "\n", "c\n", "\n", "\n",
       "4416480400forward->4416458128forward\n", "\n", "\n", "\n", "\n",
       "4416458128forward->4416480304forward\n", "\n", "\n", "\n", "\n", "\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A slightly more complex computation graph: f = (a + b) * (a * c)\n",
    "a = Scalar(1.0, label='a')\n",
    "b = Scalar(2.0, label='b')\n",
    "c = Scalar(4.0, label='c')\n",
    "d = a + b\n",
    "e = a * c\n",
    "f = d * e\n",
    "# backward() is handed draw_graph; its result is iterated in the last cell.\n",
    "# NOTE(review): presumably one drawing per backward step - confirm against utils.\n",
    "backward_process = f.backward(draw_graph)\n",
    "draw_graph(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "\n", "\n", "\n", "\n", "\n", "\n", "%3\n", "\n", "\n",
       "4416480256backward\n", "\n", "grad= 4.00\n", "\n", "value= 2.00\n", "\n", "b\n", "\n", "\n",
       "4416480304backward\n", "\n", "grad= 1.00\n", "\n", "value= 12.00\n", "\n", "*\n", "\n", "\n",
       "4416457264backward\n", "\n", "grad= 4.00\n", "\n", "value= 3.00\n", "\n", "+\n", "\n", "\n",
       "4416480304backward->4416457264backward\n", "\n", "\n", " 4.00\n", "\n", "\n",
       "4416458128backward\n", "\n", "grad= 3.00\n", "\n", "value= 4.00\n", "\n", "*\n", "\n", "\n",
       "4416480304backward->4416458128backward\n", "\n", "\n", " 3.00\n", "\n", "\n",
       "4416457264backward->4416480256backward\n", "\n", "\n", " 4.00\n", "\n", "\n",
       "4416480352backward\n", "\n", "grad= 16.00\n", "\n", "value= 1.00\n", "\n", "a\n", "\n", "\n",
       "4416457264backward->4416480352backward\n", "\n", "\n", " 4.00\n", "\n", "\n",
       "4416480400backward\n", "\n", "grad= 3.00\n", "\n", "value= 4.00\n", "\n", "c\n", "\n", "\n",
       "4416458128backward->4416480352backward\n", "\n", "\n", " 12.00\n", "\n", "\n",
       "4416458128backward->4416480400backward\n", "\n", "\n", " 3.00\n", "\n", "\n", "\n"
      ],
      "text/plain": [
       ""
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Draw the same graph in 'backward' mode (gradients filled in by f.backward above).\n",
    "draw_graph(f, 'backward')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the backward-propagation process, one picture per recorded step.\n",
    "# NOTE(review): .view(name) presumably renders each picture to a file named after its\n",
    "# index and opens it in an external viewer - confirm against utils.draw_graph.\n",
    "for index, pic in enumerate(backward_process):\n",
    "    pic.view(str(index))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}