linear_ml.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # -*- coding: UTF-8 -*-
  2. """
  3. 此脚本用于实现线性回归模型
  4. """
  5. import os
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. import pandas as pd
  9. from sklearn import linear_model
  10. def read_data(path):
  11. """
  12. 使用pandas读取数据
  13. 参数
  14. ----
  15. path: String,数据的路径
  16. 返回
  17. ----
  18. data: DataFrame,建模数据
  19. """
  20. data = pd.read_csv(path)
  21. return data
  22. def train_model(x, y):
  23. """
  24. 利用训练数据,估计模型参数
  25. 参数
  26. ----
  27. x: DataFrame,特征
  28. y: DataFrame,标签
  29. 返回
  30. ----
  31. model : LinearRegression, 训练好的线性模型
  32. """
  33. # 创建一个线性回归模型
  34. model = linear_model.LinearRegression()
  35. # 训练模型,估计模型参数
  36. model.fit(x, y)
  37. return model
  38. def evaluate_model(model, x, y):
  39. """
  40. 计算线性模型的均方差和决定系数
  41. 参数
  42. ----
  43. model : LinearRegression, 训练完成的线性模型
  44. x: DataFrame,特征
  45. y: DataFrame,标签
  46. 返回
  47. ----
  48. mse : np.float64,均方差
  49. score : np.float64,决定系数
  50. """
  51. # 均方差(The mean squared error),均方差越小越好
  52. mse = np.mean(
  53. (model.predict(x) - y) ** 2)
  54. # 决定系数(Coefficient of determination),决定系数越接近1越好
  55. score = model.score(x, y)
  56. return mse, score
  57. def visualize_model(model, x, y):
  58. """
  59. 模型可视化
  60. """
  61. # 创建一个图形框
  62. fig = plt.figure(figsize=(6, 6), dpi=80)
  63. # 在图形框里只画一幅图
  64. ax = fig.add_subplot(111)
  65. ax.set_xlabel('$x$')
  66. ax.set_ylabel('$y$')
  67. # 画点图,用蓝色圆点表示原始数据
  68. ax.scatter(x, y, color='b')
  69. # 根据截距的正负,打印不同的标签
  70. ax.plot(x, model.predict(x), color='r',
  71. label=u'$y = %.3fx$ + %.3f' % (model.coef_, model.intercept_))
  72. plt.legend(shadow=True)
  73. # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭
  74. # 在Python shell里面,可以设置参数"block=False",使阻断失效。
  75. plt.show()
  76. def run_model(data):
  77. """
  78. 线性回归模型建模步骤展示
  79. 参数
  80. ----
  81. data : DataFrame,建模数据
  82. """
  83. features = ["x"]
  84. label = ["y"]
  85. # 产生并训练模型
  86. model = train_model(data[features], data[label])
  87. # 评价模型效果
  88. mse, score = evaluate_model(model, data[features], data[label])
  89. print("MSE is %f" % mse)
  90. print("R2 is %f" % score)
  91. # 图形化模型结果
  92. visualize_model(model, data[features], data[label])
  93. if __name__ == "__main__":
  94. home_path = os.path.dirname(os.path.abspath(__file__))
  95. # Windows下的存储路径与Linux并不相同
  96. if os.name == "nt":
  97. data_path = "%s\\simple_example.csv" % home_path
  98. else:
  99. data_path = "%s/simple_example.csv" % home_path
  100. data = read_data(data_path)
  101. run_model(data)