| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- # -*- coding: UTF-8 -*-
- """
- 此脚本用于随机生成线性回归模型的训练数据
- """
- import os
- import matplotlib.pyplot as plt
- import numpy as np
- import pandas as pd
- def generate_data():
- """
- 随机生成数据
- """
- # 规定随机数生成的种子
- np.random.seed(4889)
- # Python2和Python3的range并不兼容,所以使用list(range(10, 29))
- x = np.array([10] + list(range(10, 29)))
- error = np.round(np.random.randn(20), 2)
- y = x + error
- return pd.DataFrame({"x": x, "y": y})
- def visualize_data(data):
- """
- 数据可视化
- """
- # 创建一个图形框,在里面只画一幅图
- fig = plt.figure(figsize=(6, 6), dpi=80)
- ax = fig.add_subplot(111)
- # 设置坐标轴
- ax.set_xlabel("$x$")
- ax.set_xticks(range(10, 31, 5))
- ax.set_ylabel("$y$")
- ax.set_yticks(range(10, 31, 5))
- # 画点图,点的颜色为蓝色
- ax.scatter(data.x, data.y, color="b",
- label="$y = x + \epsilon$")
- plt.legend(shadow=True)
- # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭
- # 在Python shell里面,可以设置参数"block=False",使阻断失效。
- plt.show()
- if __name__ == "__main__":
- data = generate_data()
- home_path = os.path.dirname(os.path.abspath(__file__))
- # 存储数据,Windows下的存储路径与Linux并不相同
- if os.name == "nt":
- data.to_csv("%s\\simple_example.csv" % home_path, index=False)
- else:
- data.to_csv("%s/simple_example.csv" % home_path, index=False)
- visualize_data(data)
|