ols_vs_lad.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # -*- coding: UTF-8 -*-
  2. """
  3. 此脚本用于比较LAD线性回归和OLS线性回归
  4. """
  5. import statsmodels.api as sm
  6. from sklearn import linear_model
  7. from statsmodels.regression.quantile_regression import QuantReg
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. import pandas as pd
  11. def generate_data():
  12. """
  13. 随机生成数据
  14. """
  15. np.random.seed(4889)
  16. # Python2和Python3的range并不兼容,所以使用list(range(10, 29))
  17. x = np.array([10] + list(range(10, 29)))
  18. error = np.round(np.random.randn(20), 2)
  19. y = x + error
  20. # 增加异常点
  21. x = np.append(x, 29)
  22. y = np.append(y, 29 * 10)
  23. return pd.DataFrame({"x": x, "y": y})
  24. def train_OLS(x, y):
  25. """
  26. 训练OLS线性回归模型,并返回模型预测值
  27. """
  28. model = linear_model.LinearRegression()
  29. model.fit(x, y)
  30. re = model.predict(x)
  31. return re
  32. def train_LAD(x, y):
  33. """
  34. 训练LAD线性回归模型,并返回模型预测值
  35. """
  36. X = sm.add_constant(x)
  37. model = QuantReg(y, X)
  38. model = model.fit(q=0.5)
  39. re = model.predict(X)
  40. return re
  41. def visualize_model(x, y, ols, lad):
  42. """
  43. 模型结果可视化
  44. """
  45. # 创建一个图形框
  46. fig = plt.figure(figsize=(6, 6), dpi=80)
  47. # 在图形框里只画一幅图
  48. ax = fig.add_subplot(111)
  49. # 设置坐标轴
  50. ax.set_xlabel("$x$")
  51. ax.set_xticks(range(10, 31, 5))
  52. ax.set_ylabel("$y$")
  53. # 画点图,点的颜色为蓝色,半透明
  54. ax.scatter(x, y, color="b", alpha=0.4)
  55. # 将模型结果可视化出来
  56. # 用红色虚线表示OLS线性回归模型的结果
  57. ax.plot(x, ols, 'r--', label="OLS")
  58. # 用黑色实线表示LAD线性回归模型的结果
  59. ax.plot(x, lad, 'k', label="LAD")
  60. plt.legend(shadow=True)
  61. # 展示上面所画的图片。图片将阻断程序的运行,直至所有的图片被关闭
  62. # 在Python shell里面,可以设置参数"block=False",使阻断失效
  63. plt.show()
  64. def OLS_vs_LAD(data):
  65. """
  66. 比较OLS模型和LAD模型的差异
  67. """
  68. features = ["x"]
  69. label = ["y"]
  70. ols = train_OLS(data[features], data[label])
  71. lad = train_LAD(data[features], data[label])
  72. visualize_model(data[features], data[label], ols, lad)
  73. if __name__ == "__main__":
  74. data = generate_data()
  75. OLS_vs_LAD(data)