{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: pygraphviz in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (1.11)\n",
"\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mDEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"# `from graphviz import Digraph` further down needs the PyPI package `graphviz`;\n",
"# `pygraphviz` is a separate binding that is imported as `pygraphviz`.\n",
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"# Rendering also requires the Graphviz system binaries (`dot`) on PATH.\n",
"%pip install graphviz"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"class ScalarTmp:\n",
"    \"\"\"Scalar node of a computation graph that records local derivatives.\n",
"\n",
"    Operators build new nodes and fill in ``grad_wrt``, a dict mapping each\n",
"    input node to the partial derivative of the new node with respect to it.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, value, prevs=None, op=None, label=''):\n",
"        \"\"\"Create a node.\n",
"\n",
"        Args:\n",
"            value: value held by the node.\n",
"            prevs: list of direct predecessor nodes; omit it for a leaf.\n",
"            op: symbol of the operation that produced the node, or None.\n",
"            label: variable name of the node.\n",
"        \"\"\"\n",
"        # Value of the node\n",
"        self.value = value\n",
"        # Direct predecessors. A fresh list is created per instance: a\n",
"        # literal [] default would be a single list shared by every leaf.\n",
"        self.prevs = [] if prevs is None else prevs\n",
"        # Operator symbol (op) or variable name (label)\n",
"        self.op = op\n",
"        self.label = label\n",
"        self.grad = 0.0\n",
"        self.grad_wrt = {}\n",
"\n",
"    def __repr__(self):\n",
"        \"\"\"Show the node as 'value | op | label'.\"\"\"\n",
"        return f'{self.value} | {self.op} | {self.label}'\n",
"\n",
"    def __add__(self, other):\n",
"        \"\"\"Handle ``self + other``: new '+' node, both local derivatives are 1.\"\"\"\n",
"        value = self.value + other.value\n",
"        prevs = [self, other]\n",
"        output = ScalarTmp(value, prevs, op='+')\n",
"        output.grad_wrt[self] = 1\n",
"        output.grad_wrt[other] = 1\n",
"        return output\n",
"\n",
"    def __mul__(self, other):\n",
"        \"\"\"Handle ``self * other``: new '*' node; each local derivative is the other operand's value.\"\"\"\n",
"        value = self.value * other.value\n",
"        prevs = [self, other]\n",
"        output = ScalarTmp(value, prevs, op='*')\n",
"        output.grad_wrt[self] = other.value\n",
"        output.grad_wrt[other] = self.value\n",
"        return output"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3.0 | + | "
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Build two leaf nodes and add them; c records a and b as its predecessors.\n",
"a = ScalarTmp(1.0, label='a')\n",
"b = ScalarTmp(2.0, label='b')\n",
"c = a + b\n",
"c"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from graphviz import Digraph\n",
"\n",
"def _trace(root):\n",
"    \"\"\"Collect every node and edge reachable from ``root`` through ``prevs``.\n",
"\n",
"    An explicit stack is used instead of recursion so that long chains of\n",
"    operations do not hit Python's recursion limit.\n",
"\n",
"    Returns:\n",
"        (nodes, edges): the set of nodes and the set of (prev, node) pairs.\n",
"    \"\"\"\n",
"    nodes, edges = set(), set()\n",
"    stack = [root]\n",
"    while stack:\n",
"        v = stack.pop()\n",
"        if v in nodes:\n",
"            # Already expanded via another path.\n",
"            continue\n",
"        nodes.add(v)\n",
"        for prev in v.prevs:\n",
"            edges.add((prev, v))\n",
"            stack.append(prev)\n",
"    return nodes, edges\n",
" \n",
"\n",
"def draw_graph(root, direction='forward'):\n",
"    \"\"\"Render the computation graph ending at ``root`` as a graphviz Digraph.\n",
"\n",
"    Args:\n",
"        root: node whose graph (everything reachable via ``prevs``) is drawn.\n",
"        direction: 'forward' selects rankdir 'BT'; any other value selects 'TB'.\n",
"\n",
"    Returns:\n",
"        The Digraph, created with format='svg'.\n",
"    \"\"\"\n",
"    nodes, edges = _trace(root)\n",
"    if direction == 'forward':\n",
"        rankdir = 'BT'\n",
"    else:\n",
"        rankdir = 'TB'\n",
"    graph = Digraph(format='svg', graph_attr={'rankdir': rankdir})\n",
"    # One record-shaped box per node: gradient, value, then the op symbol\n",
"    # (or the variable label when the node has no op).\n",
"    for node in nodes:\n",
"        label = node.op if node.op is not None else node.label\n",
"        node_attr = f'{{ grad={node.grad:.2f} | value={node.value:.2f} | {label}}}'\n",
"        graph.node(name=str(id(node)), label=node_attr, shape='record')\n",
"    # One edge per (prev, node) pair; id() values serve as the graphviz names.\n",
"    for src, dst in edges:\n",
"        graph.edge(str(id(src)), str(id(dst)))\n",
"    return graph"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Draw the graph of c using the default direction ('forward').\n",
"draw_graph(c)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# f = (a + b) * (a * c); a feeds both d and e, so it reaches f along two paths.\n",
"a = ScalarTmp(1.0, label='a')\n",
"b = ScalarTmp(2.0, label='b')\n",
"c = ScalarTmp(4.0, label='c')\n",
"d = a + b\n",
"e = a * c\n",
"f = d * e\n",
"draw_graph(f)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "unsupported operand type(s) for -: 'ScalarTmp' and 'ScalarTmp'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0ma\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for -: 'ScalarTmp' and 'ScalarTmp'"
]
}
],
"source": [
"# ScalarTmp defines only __add__ and __mul__, so subtraction raises TypeError.\n",
"a - b"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# 拓扑排序\n",
"def _top_order(root):\n",
"    \"\"\"Return the nodes reachable from ``root`` in topological order.\n",
"\n",
"    Every node appears after all of its ``prevs``; ``root`` is last. This is\n",
"    a depth-first post-order walk done with an explicit stack, so deep\n",
"    graphs do not hit Python's recursion limit.\n",
"    \"\"\"\n",
"    ordered, visited = [], {root}\n",
"    # Each stack entry pairs a node with an iterator over the predecessors\n",
"    # that still have to be explored.\n",
"    stack = [(root, iter(root.prevs))]\n",
"    while stack:\n",
"        node, pending = stack[-1]\n",
"        for prev in pending:\n",
"            if prev not in visited:\n",
"                visited.add(prev)\n",
"                stack.append((prev, iter(prev.prevs)))\n",
"                break\n",
"        else:\n",
"            # All predecessors are already emitted, so the node can follow.\n",
"            stack.pop()\n",
"            ordered.append(node)\n",
"    return ordered"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1.0 | None | a,\n",
" 2.0 | None | b,\n",
" 3.0 | + | ,\n",
" 4.0 | None | c,\n",
" 4.0 | * | ,\n",
" 12.0 | * | ]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# List the nodes of f's graph so that every node comes after its predecessors.\n",
"_top_order(f)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def backward(root):\n",
"    \"\"\"Back-propagate from ``root``, filling in ``grad`` on every node of its graph.\n",
"\n",
"    All gradients in the graph are reset first, so calling this again on the\n",
"    same graph (for example by re-running the cell) gives the same result\n",
"    instead of stacking new contributions on top of the previous pass.\n",
"\n",
"    Returns:\n",
"        root, whose grad is 1.0.\n",
"    \"\"\"\n",
"    ordered = _top_order(root)\n",
"    # Clear what an earlier pass left behind, intermediate nodes included.\n",
"    for node in ordered:\n",
"        node.grad = 0.0\n",
"    # The gradient of the root with respect to itself is 1\n",
"    root.grad = 1.0\n",
"    # Reverse topological order: a node's grad is complete before it is\n",
"    # passed on to its predecessors.\n",
"    for node in reversed(ordered):\n",
"        for v in node.prevs:\n",
"            v.grad += node.grad * node.grad_wrt[v]\n",
"    return root"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"12.0 | * | "
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fill in grad on every node of f's graph; the call returns f itself.\n",
"backward(f)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Redraw f: each box shows the node's grad field next to its value.\n",
"draw_graph(f)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"\n",
"# Make the helpers in ../ch07_autograd importable. The path is resolved\n",
"# against the current working directory; the membership check keeps\n",
"# re-runs of this cell from appending the same entry again.\n",
"_AUTOGRAD_DIR = os.path.abspath('../ch07_autograd')\n",
"if _AUTOGRAD_DIR not in sys.path:\n",
"    sys.path.append(_AUTOGRAD_DIR)\n",
"\n",
"# NOTE: this rebinds draw_graph, replacing the function defined earlier in\n",
"# this notebook with the one from utils.\n",
"from utils import Scalar, draw_graph"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# Same expression as before, now built from the Scalar class imported from utils.\n",
"a = Scalar(1.0, label='a')\n",
"b = Scalar(2.0, label='b')\n",
"c = Scalar(4.0, label='c')\n",
"d = a + b\n",
"e = a * c\n",
"f = d * e\n",
"# NOTE(review): backward receives draw_graph and its result is iterated in a\n",
"# later cell; presumably it produces one drawing per step - confirm in utils.\n",
"backward_process = f.backward(draw_graph)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# draw_graph is now the version imported from utils.\n",
"# NOTE(review): 'backward' presumably changes the layout direction - confirm in utils.\n",
"draw_graph(f, 'backward')"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: IPython in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (7.19.0)\n",
"Requirement already satisfied: setuptools>=18.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from IPython) (68.1.2)\n",
"Requirement already satisfied: jedi>=0.10 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from IPython) (0.17.1)\n",
"Requirement already satisfied: decorator in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from IPython) (4.4.2)\n",
"Requirement already satisfied: pickleshare in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from IPython) (0.7.5)\n",
"Requirement already satisfied: traitlets>=4.2 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from IPython) (5.0.5)\n",
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from IPython) (3.0.8)\n",
"Requirement already satisfied: pygments in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from IPython) (2.16.1)\n",
"Requirement already satisfied: backcall in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from IPython) (0.2.0)\n",
"Requirement already satisfied: pexpect>4.3 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from IPython) (4.8.0)\n",
"Requirement already satisfied: appnope in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from IPython) (0.1.0)\n",
"Requirement already satisfied: parso<0.8.0,>=0.7.0 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from jedi>=0.10->IPython) (0.7.0)\n",
"Requirement already satisfied: ptyprocess>=0.5 in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from pexpect>4.3->IPython) (0.6.0)\n",
"Requirement already satisfied: wcwidth in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->IPython) (0.2.5)\n",
"Requirement already satisfied: ipython-genutils in /Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages (from traitlets>=4.2->IPython) (0.2.0)\n",
"\u001b[33mWARNING: Ignoring invalid distribution -y-mini-racer (/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages)\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mDEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"# IPython is already a dependency of the Jupyter Python kernel, so this is\n",
"# normally a no-op; %pip keeps the install tied to the kernel's environment.\n",
"%pip install IPython"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import clear_output\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Show the items of backward_process one at a time, three seconds apart.\n",
"# clear_output(wait=True) defers the clearing until the next output arrives,\n",
"# so each item is swapped for the next one.\n",
"for frame in backward_process:\n",
"    display(frame)\n",
"    time.sleep(3)\n",
"    clear_output(wait=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}