utils.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # -*- coding: UTF-8 -*-
  2. """
  3. 此脚本用于定义游戏以及相应的可视化工具
  4. """
  5. import matplotlib.pyplot as plt
  6. import torch
  7. import pandas as pd
  8. class Lottery:
  9. def __init__(self):
  10. self.params = {
  11. 'w': (1, 1),
  12. 'l': (-1, 1)
  13. }
  14. def reset(self):
  15. self.state = 'w' if torch.randn(1).item() > 0 else 'l'
  16. return self.state
  17. def step(self, action):
  18. # 如果状态是t,则终止游戏
  19. if self.state == 't':
  20. return self.state, 0
  21. # 1表示抽奖; 0表示终止
  22. center, std = self.params[self.state]
  23. if action == 0:
  24. self.state = 't'
  25. return 't', 0
  26. else:
  27. reward = torch.normal(center, std, (1,)).item()
  28. # 有10%的概率终止游戏
  29. if torch.rand(1).item() < 0.01:
  30. self.state = 't'
  31. return self.state, reward
  32. def plot_values(v):
  33. # 为在Matplotlib中显示中文,设置特殊字体
  34. plt.rcParams["font.sans-serif"] = ["SimHei"]
  35. # 正确显示负号
  36. plt.rcParams['axes.unicode_minus'] = False
  37. plt.rcParams.update({'font.size': 13})
  38. # 创建一个图形框
  39. fig = plt.figure(figsize=(6, 6), dpi=100)
  40. v = pd.DataFrame(v)
  41. for k in v:
  42. v[k].plot(label=k, legend=True)
  43. legend = plt.legend(shadow=True, loc="best", fontsize=20)
  44. plt.yticks(range(-10, 11, 4))
  45. return fig
  46. def plot_action_probs(v):
  47. # 为在Matplotlib中显示中文,设置特殊字体
  48. plt.rcParams["font.sans-serif"] = ["SimHei"]
  49. # 正确显示负号
  50. plt.rcParams['axes.unicode_minus'] = False
  51. plt.rcParams.update({'font.size': 13})
  52. # 创建一个图形框
  53. fig = plt.figure(figsize=(6, 6), dpi=100)
  54. v = pd.DataFrame(v)
  55. for k in v:
  56. v[k].apply(lambda x: x[1]).plot(label=k, legend=True)
  57. legend = plt.legend(shadow=True, loc="best", fontsize=20)
  58. return fig