generate_data.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # -*- coding: UTF-8 -*-
  2. """
  3. 此脚本用于随机生成线性回归模型的训练数据
  4. """
  5. import os
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. import pandas as pd
  9. def generate_data():
  10. """
  11. 随机生成数据
  12. """
  13. # 规定随机数生成的种子
  14. np.random.seed(4889)
  15. # Python2和Python3的range并不兼容,所以使用list(range(10, 29))
  16. x = np.array([10] + list(range(10, 29)))
  17. error = np.round(np.random.randn(20), 2)
  18. y = x + error
  19. return pd.DataFrame({"x": x, "y": y})
  20. def visualize_data(data):
  21. """
  22. 数据可视化
  23. """
  24. # 创建一个图形框,在里面只画一幅图
  25. fig = plt.figure(figsize=(6, 6), dpi=80)
  26. ax = fig.add_subplot(111)
  27. # 设置坐标轴
  28. ax.set_xlabel("$x$")
  29. ax.set_xticks(range(10, 31, 5))
  30. ax.set_ylabel("$y$")
  31. ax.set_yticks(range(10, 31, 5))
  32. # 画点图,点的颜色为蓝色
  33. ax.scatter(data.x, data.y, color="b",
  34. label="$y = x + \epsilon$")
  35. plt.legend(shadow=True)
  36. # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭
  37. # 在Python shell里面,可以设置参数"block=False",使阻断失效。
  38. plt.show()
  39. if __name__ == "__main__":
  40. data = generate_data()
  41. home_path = os.path.dirname(os.path.abspath(__file__))
  42. # 存储数据,Windows下的存储路径与Linux并不相同
  43. if os.name == "nt":
  44. data.to_csv("%s\\simple_example.csv" % home_path, index=False)
  45. else:
  46. data.to_csv("%s/simple_example.csv" % home_path, index=False)
  47. visualize_data(data)