params_ci.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # -*- coding: UTF-8 -*-
  2. """
  3. 此脚本用于展示如何正确理解参数的置信区间
  4. """
  5. import numpy as np
  6. import statsmodels.api as sm
  7. import matplotlib.pyplot as plt
  8. import pandas as pd
  9. def generate_data():
  10. """
  11. 随机生成数据
  12. """
  13. # Python2和Python3的range并不兼容,所以使用list(range(xx, xx))
  14. x = np.array(list(range(0, 100)))
  15. error = np.round(np.random.randn(100), 2)
  16. y = x + error
  17. return pd.DataFrame({"x": x, "y": y})
  18. def train_model(x, y):
  19. """
  20. 利用训练数据,估计模型参数
  21. """
  22. # 创建一个线性回归模型
  23. model = sm.OLS(y, x)
  24. # 训练模型,估计模型参数
  25. re = model.fit()
  26. return re
  27. def visualize_ci(ci):
  28. """
  29. 可视化参数a估计值置信区间的分布
  30. """
  31. # 创建一个图形框
  32. fig = plt.figure(figsize=(6, 6), dpi=80)
  33. # 在图形框里只画一幅图
  34. ax = fig.add_subplot(111)
  35. # 将每一个95%置信区间用竖线表示
  36. for i in range(len(ci) - 1):
  37. ci_low = ci[i][0]
  38. ci_up = ci[i][1]
  39. # 如果置信区间不包含1,则用红色表示,否则用蓝色表示
  40. include_one = (ci_low < 1) & (ci_up > 1)
  41. colors = "b" if include_one else "r"
  42. ax.vlines(x=i + 1, ymin=ci_low, ymax=ci_up, colors=colors)
  43. # 用黑线将真实值1表示出来
  44. ax.hlines(1, xmin=0, xmax=len(ci))
  45. plt.show()
  46. def run():
  47. """
  48. 产生“结构”相似的随机数据,并它训练线性回归模型,得到模型参数的置信区间
  49. 以此展示模型参数估计值置信区间的真实含义
  50. """
  51. features = ["x"]
  52. label = ["y"]
  53. ci = []
  54. # 循环运行100次
  55. for i in range(100):
  56. data = generate_data()
  57. X = sm.add_constant(data[features])
  58. re = train_model(X, data[label])
  59. # 记录每一次参数a的95置信区间
  60. ci.append(re.conf_int(alpha=0.05).loc["x"].values)
  61. visualize_ci(ci)
  62. if __name__ == "__main__":
  63. run()