linear_stats.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # -*- coding: UTF-8 -*-
  2. """
  3. 此脚本用于实现线性回归模型的统计分析
  4. """
  5. # 保证脚本与Python2兼容
  6. from __future__ import print_function
  7. import os
  8. import statsmodels.api as sm
  9. from statsmodels.sandbox.regression.predstd import wls_prediction_std
  10. import matplotlib.pyplot as plt
  11. import pandas as pd
  12. def read_data(path):
  13. """
  14. 使用pandas读取数据
  15. 参数
  16. ----
  17. path: String,数据的路径
  18. 返回
  19. ----
  20. data: DataFrame,建模数据
  21. """
  22. data = pd.read_csv(path)
  23. return data
  24. def train_model(x, y):
  25. """
  26. 利用训练数据,估计模型参数
  27. 参数
  28. ----
  29. x: DataFrame,特征
  30. y: DataFrame,标签
  31. 返回
  32. ----
  33. res : RegressionResults, 训练好的线性模型
  34. """
  35. # 创建一个线性回归模型
  36. model = sm.OLS(y, x)
  37. # 训练模型,估计模型参数
  38. res = model.fit()
  39. return res
  40. def model_summary(res):
  41. """
  42. 分析线性回归模型的统计性质
  43. """
  44. # 整体统计分析结果
  45. print(res.summary())
  46. # 用f test检测x对应的系数a是否显著
  47. print("检验假设x的系数等于0:")
  48. print(res.f_test("x=0"))
  49. # 用f test检测常量b是否显著
  50. print("检测假设const的系数等于0:")
  51. print(res.f_test("const=0"))
  52. # 用f test检测a=1, b=0同时成立的显著性
  53. print("检测假设x的系数等于1和const的系数等于0同时成立:")
  54. print(res.f_test(["x=1", "const=0"]))
  55. def get_prediction(res, x):
  56. """
  57. 得到模型的预测结果以及结果的上下限
  58. """
  59. prstd, ci_low, ci_up = wls_prediction_std(res, alpha=0.05)
  60. pred = res.predict(x)
  61. return pd.DataFrame({"ci_low": ci_low, "pred": pred, "ci_up": ci_up})
  62. def visualize_model(pred, x, y):
  63. """
  64. 模型可视化
  65. """
  66. # 创建一个图形框
  67. fig = plt.figure(figsize=(6, 6), dpi=80)
  68. # 在图形框里只画一幅图
  69. ax = fig.add_subplot(111)
  70. ax.set_xlabel('$x$')
  71. ax.set_ylabel('$y$')
  72. # 画点图,用蓝色圆点表示原始数据
  73. ax.scatter(x, y, color='b')
  74. # 将模型预测结果画在图上
  75. ax.plot(x, pred["pred"], "r", label="prediction")
  76. # 将预测结果置信区间的上下限画在图上
  77. ax.plot(x, pred["ci_low"], "r--", label="95% confidence interval")
  78. ax.plot(x, pred["ci_up"], "r--", label="")
  79. plt.legend(shadow=True)
  80. # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭
  81. # 在Python shell里面,可以设置参数"block=False",使阻断失效。
  82. plt.show()
  83. def run_model(data):
  84. """
  85. 线性回归模型统计建模步骤展示
  86. """
  87. features = ["x"]
  88. labels = ["y"]
  89. # 加入常量变量
  90. X = sm.add_constant(data[features])
  91. # 构建模型
  92. res = train_model(X, data[labels])
  93. # 分析模型效果
  94. model_summary(res)
  95. # 得到模型的预测结果
  96. pred = get_prediction(res, X)
  97. # 将模型结果可视化
  98. visualize_model(pred, data[features], data[labels])
  99. if __name__ == "__main__":
  100. home_path = os.path.dirname(os.path.abspath(__file__))
  101. # Windows下的存储路径与Linux并不相同
  102. if os.name == "nt":
  103. data_path = "%s\\simple_example.csv" % home_path
  104. else:
  105. data_path = "%s/simple_example.csv" % home_path
  106. data = read_data(data_path)
  107. run_model(data)