# utils.py
# -*- coding: UTF-8 -*-
"""
Defines the Scalar class (a scalar autodiff node) and the
accompanying graphviz visualization tools.
"""
import math

from graphviz import Digraph
  6. class Scalar:
  7. def __init__(self, value, prevs=[], op=None, label='', requires_grad=True):
  8. # 节点的值
  9. self.value = value
  10. # 节点的标识(label)和对应的运算(op),用于作图
  11. self.label = label
  12. self.op = op
  13. # 节点的前节点
  14. self.prevs = prevs
  15. # 是否需要计算该节点偏导数,即∂loss/∂self
  16. self.requires_grad = requires_grad
  17. # 存储该节点偏导数,即∂loss/∂self
  18. self.grad = 0.0
  19. # 如果该节点的prevs非空,存储所有的∂self/∂prev
  20. self.grad_wrt = dict()
  21. # 作图需要,实际上对计算没有作用
  22. self.back_prop = dict()
  23. def __repr__(self):
  24. return f'Scalar(value={self.value:.2f}, grad={self.grad:.2f})'
  25. def __add__(self, other):
  26. """
  27. 定义加法,self + other将触发该函数
  28. """
  29. if not isinstance(other, Scalar):
  30. other = Scalar(other, requires_grad=False)
  31. # output = self + other
  32. output = Scalar(self.value + other.value, [self, other], '+')
  33. output.requires_grad = self.requires_grad or other.requires_grad
  34. # 计算偏导数 ∂output/∂self = 1
  35. output.grad_wrt[self] = 1
  36. # 计算偏导数 ∂output/∂other = 1
  37. output.grad_wrt[other] = 1
  38. return output
  39. def __sub__(self, other):
  40. """
  41. 定义减法,self - other将触发该函数
  42. """
  43. if not isinstance(other, Scalar):
  44. other = Scalar(other, requires_grad=False)
  45. # output = self - other
  46. output = Scalar(self.value - other.value, [self, other], '-')
  47. output.requires_grad = self.requires_grad or other.requires_grad
  48. # 计算偏导数 ∂output/∂self = 1
  49. output.grad_wrt[self] = 1
  50. # 计算偏导数 ∂output/∂other = -1
  51. output.grad_wrt[other] = -1
  52. return output
  53. def __mul__(self, other):
  54. """
  55. 定义乘法,self * other将触发该函数
  56. """
  57. if not isinstance(other, Scalar):
  58. other = Scalar(other, requires_grad=False)
  59. # output = self * other
  60. output = Scalar(self.value * other.value, [self, other], '*')
  61. output.requires_grad = self.requires_grad or other.requires_grad
  62. # 计算偏导数 ∂output/∂self = other
  63. output.grad_wrt[self] = other.value
  64. # 计算偏导数 ∂output/∂other = self
  65. output.grad_wrt[other] = self.value
  66. return output
  67. def __pow__(self, other):
  68. """
  69. 定义乘方,self**other将触发该函数
  70. """
  71. assert isinstance(other, (int, float)), 'support only int or float in the exponent'
  72. # output = self ** other
  73. output = Scalar(self.value ** other, [self], f'^{other}')
  74. output.requires_grad = self.requires_grad
  75. # 计算偏导数 ∂output/∂self = other * self**(other-1)
  76. output.grad_wrt[self] = other * self.value**(other - 1)
  77. return output
  78. def sigmoid(self):
  79. """
  80. 定义sigmoid
  81. """
  82. s = 1 / (1 + math.exp(-1 * self.value))
  83. output = Scalar(s, [self], 'sigmoid')
  84. output.requires_grad = self.requires_grad
  85. # 计算偏导数 ∂output/∂self = output * (1 - output)
  86. output.grad_wrt[self] = s * (1 - s)
  87. return output
  88. def __rsub__(self, other):
  89. """
  90. 定义右减法,other - self将触发该函数
  91. """
  92. if not isinstance(other, Scalar):
  93. other = Scalar(other, requires_grad=False)
  94. output = Scalar(other.value - self.value, [self, other], '-')
  95. output.requires_grad = self.requires_grad or other.requires_grad
  96. output.grad_wrt[self] = -1
  97. output.grad_wrt[other] = 1
  98. return output
  99. def __radd__(self, other):
  100. """
  101. 定义右加法,other + self将触发该函数
  102. """
  103. return self.__add__(other)
  104. def __rmul__(self, other):
  105. """
  106. 定义右乘法,other * self将触发该函数
  107. """
  108. return self * other
  109. def backward(self, fn=None):
  110. """
  111. 由当前节点出发,求解以当前节点为顶点的计算图中每个节点的梯度,i.e. ∂self/∂node
  112. 参数
  113. ----
  114. fn :画图函数,如果该变量不等于None,则会返回向后传播每一步的计算的记录
  115. 返回
  116. ----
  117. re :向后传播每一步的计算的记录
  118. """
  119. def _topological_order():
  120. """
  121. 利用深度优先算法,返回计算图的拓扑排序(topological sorting)
  122. """
  123. def _add_prevs(node):
  124. if node not in visited:
  125. visited.add(node)
  126. for prev in node.prevs:
  127. _add_prevs(prev)
  128. ordered.append(node)
  129. ordered, visited = [], set()
  130. _add_prevs(self)
  131. return ordered
  132. def _compute_grad_of_prevs(node):
  133. """
  134. 由node节点出发,向后传播
  135. """
  136. # 作图需要,实际上对计算没有作用
  137. node.back_prop = dict()
  138. # 得到当前节点在计算图中的梯度。由于一个节点可以在多个计算图中出现,
  139. # 使用cg_grad记录当前计算图的梯度
  140. dnode = cg_grad[node]
  141. # 使用node.grad记录节点的累积梯度
  142. node.grad += dnode
  143. for prev in node.prevs:
  144. # 由于node节点的偏导数已经计算完成,可以向后传播
  145. # 需要注意的是,向后传播到上游节点是累加关系
  146. grad_spread = dnode * node.grad_wrt[prev]
  147. cg_grad[prev] = cg_grad.get(prev, 0.0) + grad_spread
  148. node.back_prop[prev] = node.back_prop.get(prev, 0.0) + grad_spread
  149. # 当前节点的偏导数等于1,因为∂self/∂self = 1。这是反向传播算法的起点
  150. cg_grad = {self: 1}
  151. # 为了计算每个节点的偏导数,需要使用拓扑排序的倒序来遍历计算图
  152. ordered = reversed(_topological_order())
  153. re = []
  154. for node in ordered:
  155. _compute_grad_of_prevs(node)
  156. # 作图需要,实际上对计算没有作用
  157. if fn is not None:
  158. re.append(fn(self, 'backward'))
  159. return re
  160. def _get_node_attr(node, direction='forward'):
  161. """
  162. 节点的属性
  163. """
  164. node_type = _get_node_type(node)
  165. def _forward_attr():
  166. if node_type == 'param':
  167. node_text = f'{{ grad=None | value={node.value: .2f} | {node.label}}}'
  168. return dict(label=node_text, shape='record', fontsize='10', fillcolor='springgreen', style='filled, bold')
  169. elif node_type == 'computation':
  170. node_text = f'{{ grad=None | value={node.value: .2f} | {node.op}}}'
  171. return dict(label=node_text, shape='record', fontsize='10', fillcolor='gray94', style='filled, rounded')
  172. elif node_type == 'input':
  173. if node.label == '':
  174. node_text = f'input={node.value: .2f}'
  175. else:
  176. node_text = f'{node.label}={node.value: .2f}'
  177. return dict(label=node_text, shape='oval', fontsize='10')
  178. def _backward_attr():
  179. attr = _forward_attr()
  180. attr['label'] = attr['label'].replace('grad=None', f'grad={node.grad: .2f}')
  181. if not node.requires_grad:
  182. attr['style'] = 'dashed'
  183. # 如果向后传播的梯度要么等于0,要么传给不需要梯度的节点,那么该节点用虚线表示
  184. grad_back = [v if k.requires_grad else 0 for (k, v) in node.back_prop.items()]
  185. if len(grad_back) > 0 and sum(grad_back) == 0:
  186. attr['style'] = 'dashed'
  187. return attr
  188. if direction == 'forward':
  189. return _forward_attr()
  190. else:
  191. return _backward_attr()
  192. def _get_node_type(node):
  193. """
  194. 决定节点的类型,计算节点、参数以及输入数据
  195. """
  196. if node.op is not None:
  197. return 'computation'
  198. if node.requires_grad:
  199. return 'param'
  200. return 'input'
  201. def _trace(root):
  202. """
  203. 遍历图中的所有点和边
  204. """
  205. nodes, edges = set(), set()
  206. def _build(v):
  207. if v not in nodes:
  208. nodes.add(v)
  209. for prev in v.prevs:
  210. edges.add((prev, v))
  211. _build(prev)
  212. _build(root)
  213. return nodes, edges
  214. def _draw_node(graph, node, direction='forward'):
  215. """
  216. 画节点
  217. """
  218. node_attr = _get_node_attr(node, direction)
  219. uid = str(id(node)) + direction
  220. graph.node(name=uid, **node_attr)
  221. def _draw_edge(graph, n1, n2, direction='forward'):
  222. """
  223. 画边
  224. """
  225. uid1 = str(id(n1)) + direction
  226. uid2 = str(id(n2)) + direction
  227. def _draw_back_edge():
  228. if n1.requires_grad and n2.requires_grad:
  229. grad = n2.back_prop.get(n1, None)
  230. if grad is None:
  231. graph.edge(uid2, uid1, arrowhead='none', color='deepskyblue')
  232. elif grad == 0:
  233. graph.edge(uid2, uid1, style='dashed', label=f'{grad: .2f}', color='deepskyblue')
  234. else:
  235. graph.edge(uid2, uid1, label=f'{grad: .2f}', color='deepskyblue')
  236. else:
  237. graph.edge(uid2, uid1, style='dashed', arrowhead='none', color='deepskyblue')
  238. if direction == 'forward':
  239. graph.edge(uid1, uid2)
  240. elif direction == 'backward':
  241. _draw_back_edge()
  242. else:
  243. _draw_back_edge()
  244. graph.edge(uid1, uid2)
  245. def draw_graph(root, direction='forward'):
  246. """
  247. 图形化展示由root为顶点的计算图
  248. 参数
  249. ----
  250. root :Scalar,计算图的顶点
  251. direction :str,向前传播(forward)或者反向传播(backward)
  252. 返回
  253. ----
  254. re :Digraph,计算图
  255. """
  256. nodes, edges = _trace(root)
  257. rankdir = 'BT' if direction == 'forward' else 'TB'
  258. graph = Digraph(format='svg', graph_attr={'rankdir': rankdir})
  259. for item in nodes:
  260. _draw_node(graph, item, direction)
  261. for n1, n2 in edges:
  262. _draw_edge(graph, n1, n2, direction)
  263. return graph