utils.py

# -*- coding: UTF-8 -*-
'''
This script defines the Scalar class, together with the corresponding visualization tools.
'''
from graphviz import Digraph
import math


class Scalar:

    def __init__(self, value, prevs=None, op=None, label='', requires_grad=True):
        # Value of the node
        self.value = value
        # Label of the node and the operation (op) that produced it, used for plotting
        self.label = label
        self.op = op
        # Predecessor nodes: the current node is the result of an operation,
        # and the predecessors are the operands of that operation
        self.prevs = prevs if prevs is not None else []
        # Whether the partial derivative of this node is needed,
        # i.e. ∂loss/∂self (loss denotes the final model loss)
        self.requires_grad = requires_grad
        # Partial derivative of this node, i.e. ∂loss/∂self
        self.grad = 0.0
        # If prevs is non-empty, stores every ∂self/∂prev
        self.grad_wrt = dict()
        # Only needed for plotting; has no effect on the computation
        self.back_prop = dict()

    def __repr__(self):
        return f'Scalar(value={self.value:.2f}, grad={self.grad:.2f})'

    def __add__(self, other):
        '''
        Addition: self + other triggers this function
        '''
        if not isinstance(other, Scalar):
            other = Scalar(other, requires_grad=False)
        # output = self + other
        output = Scalar(self.value + other.value, [self, other], '+')
        output.requires_grad = self.requires_grad or other.requires_grad
        # Partial derivative ∂output/∂self = 1
        output.grad_wrt[self] = 1
        # Partial derivative ∂output/∂other = 1
        output.grad_wrt[other] = 1
        return output

    def __sub__(self, other):
        '''
        Subtraction: self - other triggers this function
        '''
        if not isinstance(other, Scalar):
            other = Scalar(other, requires_grad=False)
        # output = self - other
        output = Scalar(self.value - other.value, [self, other], '-')
        output.requires_grad = self.requires_grad or other.requires_grad
        # Partial derivative ∂output/∂self = 1
        output.grad_wrt[self] = 1
        # Partial derivative ∂output/∂other = -1
        output.grad_wrt[other] = -1
        return output

    def __mul__(self, other):
        '''
        Multiplication: self * other triggers this function
        '''
        if not isinstance(other, Scalar):
            other = Scalar(other, requires_grad=False)
        # output = self * other
        output = Scalar(self.value * other.value, [self, other], '*')
        output.requires_grad = self.requires_grad or other.requires_grad
        # Partial derivative ∂output/∂self = other
        output.grad_wrt[self] = other.value
        # Partial derivative ∂output/∂other = self
        output.grad_wrt[other] = self.value
        return output

    def __pow__(self, other):
        '''
        Exponentiation: self ** other triggers this function
        '''
        assert isinstance(other, (int, float))
        # output = self ** other
        output = Scalar(self.value ** other, [self], f'^{other}')
        output.requires_grad = self.requires_grad
        # Partial derivative ∂output/∂self = other * self**(other-1)
        output.grad_wrt[self] = other * self.value**(other - 1)
        return output

    def sigmoid(self):
        '''
        The sigmoid function
        '''
        s = 1 / (1 + math.exp(-self.value))
        output = Scalar(s, [self], 'sigmoid')
        output.requires_grad = self.requires_grad
        # Partial derivative ∂output/∂self = output * (1 - output)
        output.grad_wrt[self] = s * (1 - s)
        return output

    def __rsub__(self, other):
        '''
        Right subtraction: other - self triggers this function
        '''
        if not isinstance(other, Scalar):
            other = Scalar(other, requires_grad=False)
        output = Scalar(other.value - self.value, [self, other], '-')
        output.requires_grad = self.requires_grad or other.requires_grad
        # Partial derivative ∂output/∂self = -1
        output.grad_wrt[self] = -1
        # Partial derivative ∂output/∂other = 1
        output.grad_wrt[other] = 1
        return output

    def __radd__(self, other):
        '''
        Right addition: other + self triggers this function
        '''
        return self.__add__(other)

    def __rmul__(self, other):
        '''
        Right multiplication: other * self triggers this function
        '''
        return self * other

    def backward(self, fn=None):
        '''
        Starting from the current node, compute the partial derivative of every
        node in the computational graph rooted at this node, i.e. ∂self/∂node

        Parameters
        ----
        fn : plotting function; if not None, a record of every step of the
             backward pass is returned

        Returns
        ----
        re : record of every step of the backward pass
        '''
        def _topological_order():
            '''
            Return the topological sorting of the computational graph,
            computed via depth-first search
            '''
            def _add_prevs(node):
                if node not in visited:
                    visited.add(node)
                    for prev in node.prevs:
                        _add_prevs(prev)
                    ordered.append(node)
            ordered, visited = [], set()
            _add_prevs(self)
            return ordered

        def _compute_grad_of_prevs(node):
            '''
            Propagate backwards, starting from node
            '''
            # Only needed for plotting; has no effect on the computation
            node.back_prop = dict()
            # Gradient of the current node within this computational graph.
            # Since a node can appear in several computational graphs,
            # cg_grad records the gradient for the current graph only
            dnode = cg_grad[node]
            # node.grad records the node's accumulated gradient
            node.grad += dnode
            for prev in node.prevs:
                # The partial derivative of node is complete, so it can be
                # propagated backwards. Note that gradients spread to an
                # upstream node are accumulated
                grad_spread = dnode * node.grad_wrt[prev]
                cg_grad[prev] = cg_grad.get(prev, 0.0) + grad_spread
                node.back_prop[prev] = node.back_prop.get(prev, 0.0) + grad_spread

        # The partial derivative of the current node equals 1, since
        # ∂self/∂self = 1. This is the starting point of backpropagation
        cg_grad = {self: 1}
        # To compute the partial derivative of every node, traverse the
        # computational graph in reverse topological order
        ordered = reversed(_topological_order())
        re = []
        for node in ordered:
            _compute_grad_of_prevs(node)
            # Only needed for plotting; has no effect on the computation
            if fn is not None:
                re.append(fn(self, 'backward'))
        return re
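
# A minimal usage sketch of the Scalar class (illustrative only; the variable
# names below are hypothetical and not part of this module):
#
#     x = Scalar(2.0, label='x')
#     y = Scalar(3.0, label='y')
#     z = (x * y + 1).sigmoid()
#     z.backward()
#     x.grad  # equals z.value * (1 - z.value) * y.value, i.e. ∂z/∂x
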
def _get_node_attr(node, direction='forward'):
    '''
    Attributes of a node
    '''
    node_type = _get_node_type(node)

    def _forward_attr():
        if node_type == 'param':
            node_text = f'{{ grad=None | value={node.value: .2f} | {node.label}}}'
            return dict(label=node_text, shape='record', fontsize='10',
                        fillcolor='springgreen', style='filled, bold')
        elif node_type == 'computation':
            node_text = f'{{ grad=None | value={node.value: .2f} | {node.op}}}'
            return dict(label=node_text, shape='record', fontsize='10',
                        fillcolor='gray94', style='filled, rounded')
        elif node_type == 'input':
            if node.label == '':
                node_text = f'input={node.value: .2f}'
            else:
                node_text = f'{node.label}={node.value: .2f}'
            return dict(label=node_text, shape='oval', fontsize='10')

    def _backward_attr():
        attr = _forward_attr()
        attr['label'] = attr['label'].replace('grad=None', f'grad={node.grad: .2f}')
        if not node.requires_grad:
            attr['style'] = 'dashed'
        # For nicer plots: if the gradients propagated backwards are all 0, or
        # are spread only to nodes that do not require gradients, draw the node
        # with a dashed border
        grad_back = [v if k.requires_grad else 0 for (k, v) in node.back_prop.items()]
        if len(grad_back) > 0 and sum(grad_back) == 0:
            attr['style'] = 'dashed'
        return attr

    if direction == 'forward':
        return _forward_attr()
    else:
        return _backward_attr()

def _get_node_type(node):
    '''
    Determine the type of a node: computation, parameter, or input data
    '''
    if node.op is not None:
        return 'computation'
    if node.requires_grad:
        return 'param'
    return 'input'

def _trace(root):
    '''
    Traverse all nodes and edges in the graph
    '''
    nodes, edges = set(), set()

    def _build(v):
        if v not in nodes:
            nodes.add(v)
            for prev in v.prevs:
                edges.add((prev, v))
                _build(prev)

    _build(root)
    return nodes, edges

def _draw_node(graph, node, direction='forward'):
    '''
    Draw a node
    '''
    node_attr = _get_node_attr(node, direction)
    uid = str(id(node)) + direction
    graph.node(name=uid, **node_attr)

def _draw_edge(graph, n1, n2, direction='forward'):
    '''
    Draw an edge
    '''
    uid1 = str(id(n1)) + direction
    uid2 = str(id(n2)) + direction

    def _draw_back_edge():
        if n1.requires_grad and n2.requires_grad:
            grad = n2.back_prop.get(n1, None)
            if grad is None:
                graph.edge(uid2, uid1, arrowhead='none', color='deepskyblue')
            elif grad == 0:
                graph.edge(uid2, uid1, style='dashed', label=f'{grad: .2f}', color='deepskyblue')
            else:
                graph.edge(uid2, uid1, label=f'{grad: .2f}', color='deepskyblue')
        else:
            graph.edge(uid2, uid1, style='dashed', arrowhead='none', color='deepskyblue')

    if direction == 'forward':
        graph.edge(uid1, uid2)
    elif direction == 'backward':
        _draw_back_edge()
    else:
        _draw_back_edge()
        graph.edge(uid1, uid2)

def draw_graph(root, direction='forward'):
    '''
    Visualize the computational graph rooted at root

    Parameters
    ----
    root : Scalar, root of the computational graph
    direction : str, forward pass ('forward') or backward pass ('backward')

    Returns
    ----
    re : Digraph, the computational graph
    '''
    nodes, edges = _trace(root)
    rankdir = 'BT' if direction == 'forward' else 'TB'
    graph = Digraph(format='svg', graph_attr={'rankdir': rankdir})
    for item in nodes:
        _draw_node(graph, item, direction)
    for n1, n2 in edges:
        _draw_edge(graph, n1, n2, direction)
    return graph
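

# A minimal smoke test, runnable as `python utils.py`. This is a hedged sketch:
# the example expression and labels are illustrative, not part of the module's API.
if __name__ == '__main__':
    # Build a tiny computational graph: z = sigmoid(w * x + b)
    w = Scalar(0.5, label='w')
    b = Scalar(-1.0, label='b')
    x = Scalar(2.0, label='x', requires_grad=False)
    z = (w * x + b).sigmoid()
    # Backward pass: populates w.grad = ∂z/∂w and b.grad = ∂z/∂b
    z.backward()
    print(w, b)
    # Emit the DOT source of the backward graph; rendering it to SVG would
    # additionally require the Graphviz binaries to be installed
    print(draw_graph(z, direction='backward').source)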