linear_model.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. # -*- coding: UTF-8 -*-
  2. """
  3. 此脚本用于定义线性回归模型
  4. """
  5. from utils import Scalar
  6. def mse(errors):
  7. """
  8. 计算均方误差
  9. """
  10. n = len(errors)
  11. wrt = {}
  12. value = 0.0
  13. requires_grad = False
  14. for item in errors:
  15. value += item.value ** 2 / n
  16. wrt[item] = 2 / n * item.value
  17. requires_grad = requires_grad or item.requires_grad
  18. output = Scalar(value, errors, 'mse')
  19. output.requires_grad=requires_grad
  20. output.grad_wrt = wrt
  21. return output
  22. class Linear:
  23. def __init__(self):
  24. """
  25. 定义线性回归模型的参数:a, b
  26. """
  27. self.a = Scalar(0.0, label='a')
  28. self.b = Scalar(0.0, label='b')
  29. def forward(self, x):
  30. """
  31. 根据当前的参数估计值,得到模型的预测结果
  32. """
  33. return self.a * x + self.b
  34. def error(self, x, y):
  35. """
  36. 当前数据的模型误差
  37. """
  38. return y - self.forward(x)
  39. def string(self):
  40. """
  41. 输出当前模型的结果
  42. """
  43. return f'y = {self.a.value:.2f} * x + {self.b.value:.2f}'