# -*- coding: UTF-8 -*- """ 此脚本用于实现线性回归模型的统计分析 """ # 保证脚本与Python2兼容 from __future__ import print_function import os import statsmodels.api as sm from statsmodels.sandbox.regression.predstd import wls_prediction_std import matplotlib.pyplot as plt import pandas as pd def read_data(path): """ 使用pandas读取数据 参数 ---- path: String,数据的路径 返回 ---- data: DataFrame,建模数据 """ data = pd.read_csv(path) return data def train_model(x, y): """ 利用训练数据,估计模型参数 参数 ---- x: DataFrame,特征 y: DataFrame,标签 返回 ---- res : RegressionResults, 训练好的线性模型 """ # 创建一个线性回归模型 model = sm.OLS(y, x) # 训练模型,估计模型参数 res = model.fit() return res def model_summary(res): """ 分析线性回归模型的统计性质 """ # 整体统计分析结果 print(res.summary()) # 用f test检测x对应的系数a是否显著 print("检验假设x的系数等于0:") print(res.f_test("x=0")) # 用f test检测常量b是否显著 print("检测假设const的系数等于0:") print(res.f_test("const=0")) # 用f test检测a=1, b=0同时成立的显著性 print("检测假设x的系数等于1和const的系数等于0同时成立:") print(res.f_test(["x=1", "const=0"])) def get_prediction(res, x): """ 得到模型的预测结果以及结果的上下限 """ prstd, ci_low, ci_up = wls_prediction_std(res, alpha=0.05) pred = res.predict(x) return pd.DataFrame({"ci_low": ci_low, "pred": pred, "ci_up": ci_up}) def visualize_model(pred, x, y): """ 模型可视化 """ # 创建一个图形框 fig = plt.figure(figsize=(6, 6), dpi=80) # 在图形框里只画一幅图 ax = fig.add_subplot(111) ax.set_xlabel('$x$') ax.set_ylabel('$y$') # 画点图,用蓝色圆点表示原始数据 ax.scatter(x, y, color='b') # 将模型预测结果画在图上 ax.plot(x, pred["pred"], "r", label="prediction") # 将预测结果置信区间的上下限画在图上 ax.plot(x, pred["ci_low"], "r--", label="95% confidence interval") ax.plot(x, pred["ci_up"], "r--", label="") plt.legend(shadow=True) # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭 # 在Python shell里面,可以设置参数"block=False",使阻断失效。 plt.show() def run_model(data): """ 线性回归模型统计建模步骤展示 """ features = ["x"] labels = ["y"] # 加入常量变量 X = sm.add_constant(data[features]) # 构建模型 res = train_model(X, data[labels]) # 分析模型效果 model_summary(res) # 得到模型的预测结果 pred = get_prediction(res, X) # 将模型结果可视化 visualize_model(pred, data[features], data[labels]) if __name__ == "__main__": home_path = os.path.dirname(os.path.abspath(__file__)) # Windows下的存储路径与Linux并不相同 if os.name == "nt": data_path = "%s\\simple_example.csv" % home_path else: data_path = "%s/simple_example.csv" % home_path data = read_data(data_path) run_model(data)