{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A scalar autograd engine from scratch\n",
    "\n",
    "This notebook builds a small computational graph out of scalar nodes (`ScalarTmp`), draws it with Graphviz, and implements reverse-mode differentiation (backpropagation) on top of a topological sort. It then switches to the `Scalar` class from `../ch07_autograd/utils.py` and steps through its backward pass visually.\n",
    "\n",
    "The notebook is meant to be run top to bottom (**Kernel → Restart & Run All**). Rendering requires the `graphviz` Python package *and* the Graphviz system binaries (e.g. `dot`) on the `PATH`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The notebook imports `graphviz` (`from graphviz import Digraph`); `pygraphviz` is a\n",
    "# different library that does not provide that module. `%pip` installs into this kernel's environment.\n",
    "%pip install graphviz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "from graphviz import Digraph\n",
    "from IPython.display import clear_output, display\n",
    "\n",
    "# Seconds each frame of the backward-pass animation (last cell) stays on screen.\n",
    "FRAME_DELAY_S = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. A minimal scalar node\n",
    "\n",
    "`ScalarTmp` records its value, the nodes it was computed from, and the local derivative of its operation with respect to each of those inputs. Only `+` and `*` are supported."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ScalarTmp:\n",
    "    \"\"\"A scalar node in a computational graph (teaching version).\n",
    "\n",
    "    Attributes:\n",
    "        value: the number held by this node.\n",
    "        prevs: list of direct predecessor nodes (the operands that produced it).\n",
    "        op: operator symbol ('+' or '*') for computed nodes, None for leaves.\n",
    "        label: variable name, used when drawing leaf nodes.\n",
    "        grad: gradient of the graph's root w.r.t. this node, filled by `backward`.\n",
    "        grad_wrt: local derivative d(self)/d(prev), keyed by predecessor node.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, value, prevs=None, op=None, label=''):\n",
    "        self.value = value\n",
    "        # `None` instead of a `[]` default: a list default would be one object\n",
    "        # shared by every node created without explicit predecessors.\n",
    "        self.prevs = list(prevs) if prevs is not None else []\n",
    "        self.op = op\n",
    "        self.label = label\n",
    "        self.grad = 0.0\n",
    "        self.grad_wrt = {}\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f'{self.value} | {self.op} | {self.label}'\n",
    "\n",
    "    def __add__(self, other):\n",
    "        \"\"\"Return the node for `self + other`; both local derivatives are 1.\"\"\"\n",
    "        output = ScalarTmp(self.value + other.value, [self, other], op='+')\n",
    "        output.grad_wrt[self] = 1\n",
    "        output.grad_wrt[other] = 1\n",
    "        return output\n",
    "\n",
    "    def __mul__(self, other):\n",
    "        \"\"\"Return the node for `self * other`; d/d(self) = other.value and d/d(other) = self.value.\"\"\"\n",
    "        output = ScalarTmp(self.value * other.value, [self, other], op='*')\n",
    "        output.grad_wrt[self] = other.value\n",
    "        output.grad_wrt[other] = self.value\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = ScalarTmp(1.0, label='a')\n",
    "b = ScalarTmp(2.0, label='b')\n",
    "c = a + b\n",
    "c"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Visualising the computational graph\n",
    "\n",
    "Each node is drawn as a record showing its gradient, its value, and its operator (or its variable name for leaves)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _trace(root):\n",
    "    \"\"\"Collect every node and edge reachable from `root` by following `prevs`.\n",
    "\n",
    "    Returns (nodes, edges): a set of nodes and a set of (input, output) node pairs.\n",
    "    \"\"\"\n",
    "    nodes, edges = set(), set()\n",
    "\n",
    "    def _build(v):\n",
    "        if v not in nodes:\n",
    "            nodes.add(v)\n",
    "            for prev in v.prevs:\n",
    "                edges.add((prev, v))\n",
    "                _build(prev)\n",
    "\n",
    "    _build(root)\n",
    "    return nodes, edges\n",
    "\n",
    "\n",
    "def draw_graph(root, direction='forward'):\n",
    "    \"\"\"Render the graph that ends at `root` as an SVG `graphviz.Digraph`.\n",
    "\n",
    "    direction='forward' uses rankdir 'BT' (inputs at the bottom); any other\n",
    "    value uses 'TB'.\n",
    "    \"\"\"\n",
    "    nodes, edges = _trace(root)\n",
    "    rankdir = 'BT' if direction == 'forward' else 'TB'\n",
    "    graph = Digraph(format='svg', graph_attr={'rankdir': rankdir})\n",
    "    # Nodes: id() is unique among live objects, so it serves as the Graphviz node name.\n",
    "    for node in nodes:\n",
    "        label = node.label if node.op is None else node.op\n",
    "        node_attr = f'{{ grad={node.grad:.2f} | value={node.value:.2f} | {label}}}'\n",
    "        graph.node(name=str(id(node)), label=node_attr, shape='record')\n",
    "    # Edges point from each input to the node computed from it.\n",
    "    for src, dst in edges:\n",
    "        graph.edge(str(id(src)), str(id(dst)))\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "draw_graph(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = ScalarTmp(1.0, label='a')\n",
    "b = ScalarTmp(2.0, label='b')\n",
    "c = ScalarTmp(4.0, label='c')\n",
    "d = a + b\n",
    "e = a * c\n",
    "f = d * e\n",
    "draw_graph(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`ScalarTmp` does not define subtraction, so `a - b` raises a `TypeError`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected to fail: ScalarTmp has no __sub__.\n",
    "try:\n",
    "    a - b\n",
    "except TypeError as err:\n",
    "    print(f'TypeError: {err}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Backpropagation\n",
    "\n",
    "Nodes are visited in reverse topological order. For every node $n$ and each input $v$ of $n$: $\\frac{\\partial f}{\\partial v} \\mathrel{+}= \\frac{\\partial f}{\\partial n}\\,\\frac{\\partial n}{\\partial v}$. `backward` resets all gradients first, so re-running it does not double-count."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _top_order(root):\n",
    "    \"\"\"Return the nodes reachable from `root` in topological order.\n",
    "\n",
    "    Post-order depth-first search: each node is appended only after all of its\n",
    "    predecessors, so `root` comes last.\n",
    "    \"\"\"\n",
    "    ordered, visited = [], set()\n",
    "\n",
    "    def _add_prevs(node):\n",
    "        if node not in visited:\n",
    "            visited.add(node)\n",
    "            for prev in node.prevs:\n",
    "                _add_prevs(prev)\n",
    "            ordered.append(node)\n",
    "\n",
    "    _add_prevs(root)\n",
    "    return ordered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_top_order(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def backward(root):\n",
    "    \"\"\"Backpropagate from `root`, setting `grad` on every node it depends on.\n",
    "\n",
    "    All gradients in the graph are zeroed first, so calling this again on the\n",
    "    same graph (e.g. re-running the cell) gives the same result instead of\n",
    "    adding onto the previous one. Returns `root`.\n",
    "    \"\"\"\n",
    "    ordered = _top_order(root)\n",
    "    for node in ordered:\n",
    "        node.grad = 0.0\n",
    "    # d(root)/d(root) = 1\n",
    "    root.grad = 1.0\n",
    "    # Reverse topological order: a node's grad is complete before it is pushed to its inputs.\n",
    "    # An input listed twice in `prevs` (e.g. `a * a`) is visited twice, summing both contributions.\n",
    "    for node in reversed(ordered):\n",
    "        for v in node.prevs:\n",
    "            v.grad += node.grad * node.grad_wrt[v]\n",
    "    return root"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "backward(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "draw_graph(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. The full `Scalar` implementation\n",
    "\n",
    "`Scalar` and another `draw_graph` are imported from `../ch07_autograd/utils.py`. The import **replaces** the `draw_graph` defined in section 2. Re-run that definition cell before re-running any earlier cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative to the notebook's working directory. The membership check keeps\n",
    "# re-runs from appending the same path again.\n",
    "UTILS_DIR = os.path.abspath('../ch07_autograd')\n",
    "if UTILS_DIR not in sys.path:\n",
    "    sys.path.append(UTILS_DIR)\n",
    "# Deliberately shadows the `draw_graph` defined earlier in this notebook.\n",
    "from utils import Scalar, draw_graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = Scalar(1.0, label='a')\n",
    "b = Scalar(2.0, label='b')\n",
    "c = Scalar(4.0, label='c')\n",
    "d = a + b\n",
    "e = a * c\n",
    "f = d * e\n",
    "backward_process = f.backward(draw_graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "draw_graph(f, 'backward')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Stepping through the backward pass\n",
    "\n",
    "Each item in `backward_process` (returned by `f.backward(draw_graph)` above) is displayed for `FRAME_DELAY_S` seconds and then replaced by the next. If `backward_process` is a generator, it can be iterated only once. To replay, re-run the cell that builds `f` and `backward_process`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for frame in backward_process:\n",
    "    display(frame)\n",
    "    time.sleep(FRAME_DELAY_S)\n",
    "    # wait=True defers the clear until the next frame arrives, avoiding flicker.\n",
    "    clear_output(wait=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}