# -*- coding: UTF-8 -*-
"""
This script defines the Scalar class and the corresponding visualization tools.
"""
from graphviz import Digraph
import math
  7. class Scalar:
  8. def __init__(self, value, prevs=[], op=None, label='', requires_grad=True):
  9. # 节点的值
  10. self.value = value
  11. # 节点的标识(label)和对应的运算(op),用于作图
  12. self.label = label
  13. self.op = op
  14. # 节点的前节点
  15. self.prevs = prevs
  16. # 是否需要计算该节点偏导数,即∂loss/∂self
  17. self.requires_grad = requires_grad
  18. # 存储该节点偏导数,即∂loss/∂self
  19. self.grad = 0.0
  20. # 如果该节点的prevs非空,存储所有的∂self/∂prev
  21. self.grad_wrt = dict()
  22. # 作图需要,实际上对计算没有作用
  23. self.back_prop = dict()
  24. def __repr__(self):
  25. return f'Scalar(value={self.value:.2f}, grad={self.grad:.2f})'
  26. def __add__(self, other):
  27. """
  28. 定义加法,self + other将触发该函数
  29. """
  30. if not isinstance(other, Scalar):
  31. other = Scalar(other, requires_grad=False)
  32. # output = self + other
  33. output = Scalar(self.value + other.value, [self, other], '+')
  34. output.requires_grad = self.requires_grad or other.requires_grad
  35. # 计算偏导数 ∂output/∂self = 1
  36. output.grad_wrt[self] = 1
  37. # 计算偏导数 ∂output/∂other = 1
  38. output.grad_wrt[other] = 1
  39. return output
  40. def __sub__(self, other):
  41. """
  42. 定义减法,self - other将触发该函数
  43. """
  44. if not isinstance(other, Scalar):
  45. other = Scalar(other, requires_grad=False)
  46. # output = self - other
  47. output = Scalar(self.value - other.value, [self, other], '-')
  48. output.requires_grad = self.requires_grad or other.requires_grad
  49. # 计算偏导数 ∂output/∂self = 1
  50. output.grad_wrt[self] = 1
  51. # 计算偏导数 ∂output/∂other = -1
  52. output.grad_wrt[other] = -1
  53. return output
  54. def __mul__(self, other):
  55. """
  56. 定义乘法,self * other将触发该函数
  57. """
  58. if not isinstance(other, Scalar):
  59. other = Scalar(other, requires_grad=False)
  60. # output = self * other
  61. output = Scalar(self.value * other.value, [self, other], '*')
  62. output.requires_grad = self.requires_grad or other.requires_grad
  63. # 计算偏导数 ∂output/∂self = other
  64. output.grad_wrt[self] = other.value
  65. # 计算偏导数 ∂output/∂other = self
  66. output.grad_wrt[other] = self.value
  67. return output
  68. def __pow__(self, other):
  69. """
  70. 定义乘方,self**other将触发该函数
  71. """
  72. assert isinstance(other, (int, float)), 'support only int or float in the exponent'
  73. # output = self ** other
  74. output = Scalar(self.value ** other, [self], f'^{other}')
  75. output.requires_grad = self.requires_grad
  76. # 计算偏导数 ∂output/∂self = other * self**(other-1)
  77. output.grad_wrt[self] = other * self.value**(other - 1)
  78. return output
  79. def sigmoid(self):
  80. """
  81. 定义sigmoid
  82. """
  83. s = 1 / (1 + math.exp(-1 * self.value))
  84. output = Scalar(s, [self], 'sigmoid')
  85. output.requires_grad = self.requires_grad
  86. # 计算偏导数 ∂output/∂self = output * (1 - output)
  87. output.grad_wrt[self] = s * (1 - s)
  88. return output
  89. def __rsub__(self, other):
  90. """
  91. 定义右减法,other - self将触发该函数
  92. """
  93. if not isinstance(other, Scalar):
  94. other = Scalar(other, requires_grad=False)
  95. output = Scalar(other.value - self.value, [self, other], '-')
  96. output.requires_grad = self.requires_grad or other.requires_grad
  97. output.grad_wrt[self] = -1
  98. output.grad_wrt[other] = 1
  99. return output
  100. def __radd__(self, other):
  101. """
  102. 定义右加法,other + self将触发该函数
  103. """
  104. return self.__add__(other)
  105. def __rmul__(self, other):
  106. """
  107. 定义右乘法,other * self将触发该函数
  108. """
  109. return self * other
  110. def backward(self, fn=None):
  111. """
  112. 由当前节点出发,求解以当前节点为顶点的计算图中每个节点的梯度,i.e. ∂self/∂node
  113. 参数
  114. ----
  115. fn :画图函数,如果该变量不等于None,则会返回向后传播每一步的计算的记录
  116. 返回
  117. ----
  118. re :向后传播每一步的计算的记录
  119. """
  120. def _topological_order():
  121. """
  122. 利用深度优先算法,返回计算图的拓扑排序(topological sorting)
  123. """
  124. def _add_prevs(node):
  125. if node not in visited:
  126. visited.add(node)
  127. for prev in node.prevs:
  128. _add_prevs(prev)
  129. ordered.append(node)
  130. ordered, visited = [], set()
  131. _add_prevs(self)
  132. return ordered
  133. def _compute_grad_of_prevs(node):
  134. """
  135. 由node节点出发,向后传播
  136. """
  137. # 作图需要,实际上对计算没有作用
  138. node.back_prop = dict()
  139. # 得到当前节点在计算图中的梯度。由于一个节点可以在多个计算图中出现,
  140. # 使用cg_grad记录当前计算图的梯度
  141. dnode = cg_grad[node]
  142. # 使用node.grad记录节点的累积梯度
  143. node.grad += dnode
  144. for prev in node.prevs:
  145. # 由于node节点的偏导数已经计算完成,可以向后传播
  146. # 需要注意的是,向后传播到上游节点是累加关系
  147. grad_spread = dnode * node.grad_wrt[prev]
  148. cg_grad[prev] = cg_grad.get(prev, 0.0) + grad_spread
  149. node.back_prop[prev] = node.back_prop.get(prev, 0.0) + grad_spread
  150. # 当前节点的偏导数等于1,因为∂self/∂self = 1。这是反向传播算法的起点
  151. cg_grad = {self: 1}
  152. # 为了计算每个节点的偏导数,需要使用拓扑排序的倒序来遍历计算图
  153. ordered = reversed(_topological_order())
  154. re = []
  155. for node in ordered:
  156. _compute_grad_of_prevs(node)
  157. # 作图需要,实际上对计算没有作用
  158. if fn is not None:
  159. re.append(fn(self, 'backward'))
  160. return re
  161. def _get_node_attr(node, direction='forward'):
  162. """
  163. 节点的属性
  164. """
  165. node_type = _get_node_type(node)
  166. def _forward_attr():
  167. if node_type == 'param':
  168. node_text = f'{{ grad=None | value={node.value: .2f} | {node.label}}}'
  169. return dict(label=node_text, shape='record', fontsize='10', fillcolor='springgreen', style='filled, bold')
  170. elif node_type == 'computation':
  171. node_text = f'{{ grad=None | value={node.value: .2f} | {node.op}}}'
  172. return dict(label=node_text, shape='record', fontsize='10', fillcolor='gray94', style='filled, rounded')
  173. elif node_type == 'input':
  174. if node.label == '':
  175. node_text = f'input={node.value: .2f}'
  176. else:
  177. node_text = f'{node.label}={node.value: .2f}'
  178. return dict(label=node_text, shape='oval', fontsize='10')
  179. def _backward_attr():
  180. attr = _forward_attr()
  181. attr['label'] = attr['label'].replace('grad=None', f'grad={node.grad: .2f}')
  182. if not node.requires_grad:
  183. attr['style'] = 'dashed'
  184. # 如果向后传播的梯度要么等于0,要么传给不需要梯度的节点,那么该节点用虚线表示
  185. grad_back = [v if k.requires_grad else 0 for (k, v) in node.back_prop.items()]
  186. if len(grad_back) > 0 and sum(grad_back) == 0:
  187. attr['style'] = 'dashed'
  188. return attr
  189. if direction == 'forward':
  190. return _forward_attr()
  191. else:
  192. return _backward_attr()
  193. def _get_node_type(node):
  194. """
  195. 决定节点的类型,计算节点、参数以及输入数据
  196. """
  197. if node.op is not None:
  198. return 'computation'
  199. if node.requires_grad:
  200. return 'param'
  201. return 'input'
  202. def _trace(root):
  203. """
  204. 遍历图中的所有点和边
  205. """
  206. nodes, edges = set(), set()
  207. def _build(v):
  208. if v not in nodes:
  209. nodes.add(v)
  210. for prev in v.prevs:
  211. edges.add((prev, v))
  212. _build(prev)
  213. _build(root)
  214. return nodes, edges
  215. def _draw_node(graph, node, direction='forward'):
  216. """
  217. 画节点
  218. """
  219. node_attr = _get_node_attr(node, direction)
  220. uid = str(id(node)) + direction
  221. graph.node(name=uid, **node_attr)
  222. def _draw_edge(graph, n1, n2, direction='forward'):
  223. """
  224. 画边
  225. """
  226. uid1 = str(id(n1)) + direction
  227. uid2 = str(id(n2)) + direction
  228. def _draw_back_edge():
  229. if n1.requires_grad and n2.requires_grad:
  230. grad = n2.back_prop.get(n1, None)
  231. if grad is None:
  232. graph.edge(uid2, uid1, arrowhead='none', color='deepskyblue')
  233. elif grad == 0:
  234. graph.edge(uid2, uid1, style='dashed', label=f'{grad: .2f}', color='deepskyblue')
  235. else:
  236. graph.edge(uid2, uid1, label=f'{grad: .2f}', color='deepskyblue')
  237. else:
  238. graph.edge(uid2, uid1, style='dashed', arrowhead='none', color='deepskyblue')
  239. if direction == 'forward':
  240. graph.edge(uid1, uid2)
  241. elif direction == 'backward':
  242. _draw_back_edge()
  243. else:
  244. _draw_back_edge()
  245. graph.edge(uid1, uid2)
  246. def draw_graph(root, direction='forward'):
  247. """
  248. 图形化展示由root为顶点的计算图
  249. 参数
  250. ----
  251. root :Scalar,计算图的顶点
  252. direction :str,向前传播(forward)或者反向传播(backward)
  253. 返回
  254. ----
  255. re :Digraph,计算图
  256. """
  257. nodes, edges = _trace(root)
  258. rankdir = 'BT' if direction == 'forward' else 'TB'
  259. graph = Digraph(format='svg', graph_attr={'rankdir': rankdir})
  260. for item in nodes:
  261. _draw_node(graph, item, direction)
  262. for n1, n2 in edges:
  263. _draw_edge(graph, n1, n2, direction)
  264. return graph