utils.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # -*- coding: UTF-8 -*-
  2. '''
  3. 此脚本用于定义游戏以及相应的可视化工具
  4. '''
  5. import matplotlib.pyplot as plt
  6. import torch
  7. import pandas as pd
  8. class Lottery:
  9. def __init__(self):
  10. # 定义游戏的两个状态
  11. self.params = {
  12. 'w': (1, 1),
  13. 'l': (-1, 1)
  14. }
  15. def reset(self):
  16. self.state = 'w' if torch.randn(1).item() > 0 else 'l'
  17. return self.state
  18. def step(self, action):
  19. # 如果状态是t,则终止游戏
  20. if self.state == 't':
  21. return self.state, 0
  22. # 1表示抽奖; 0表示终止
  23. center, std = self.params[self.state]
  24. if action == 0:
  25. self.state = 't'
  26. return 't', 0
  27. else:
  28. reward = torch.normal(center, std, (1,)).item()
  29. # 有10%的概率终止游戏
  30. if torch.rand(1).item() < 0.01:
  31. self.state = 't'
  32. return self.state, reward
  33. def plot_values(v):
  34. # 为在Matplotlib中显示中文,设置特殊字体
  35. plt.rcParams['font.sans-serif'] = ['SimHei']
  36. # 正确显示负号
  37. plt.rcParams['axes.unicode_minus'] = False
  38. plt.rcParams.update({'font.size': 13})
  39. # 创建一个图形框
  40. fig = plt.figure(figsize=(6, 6), dpi=100)
  41. v = pd.DataFrame(v)
  42. for k in v:
  43. v[k].plot(label=k, legend=True)
  44. legend = plt.legend(shadow=True, loc='best', fontsize=20)
  45. plt.yticks(range(-10, 11, 4))
  46. return fig
  47. def plot_action_probs(v):
  48. # 为在Matplotlib中显示中文,设置特殊字体
  49. plt.rcParams['font.sans-serif'] = ['SimHei']
  50. # 正确显示负号
  51. plt.rcParams['axes.unicode_minus'] = False
  52. plt.rcParams.update({'font.size': 13})
  53. # 创建一个图形框
  54. fig = plt.figure(figsize=(6, 6), dpi=100)
  55. v = pd.DataFrame(v)
  56. for k in v:
  57. # 在图中画出抽奖的概率
  58. v[k].apply(lambda x: x[1]).plot(label=k, legend=True)
  59. legend = plt.legend(shadow=True, loc='best', fontsize=20)
  60. return fig