linear_model.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # -*- coding: UTF-8 -*-
  2. '''
  3. 此脚本用于定义线性回归模型
  4. '''
  5. from utils import Scalar
  6. def mse(errors):
  7. '''
  8. 计算均方误差
  9. '''
  10. n = len(errors)
  11. wrt = {}
  12. value = 0.0
  13. requires_grad = False
  14. for item in errors:
  15. value += item.value ** 2 / n
  16. wrt[item] = 2 / n * item.value
  17. requires_grad = requires_grad or item.requires_grad
  18. output = Scalar(value, errors, 'mse')
  19. output.requires_grad=requires_grad
  20. output.grad_wrt = wrt
  21. return output
  22. class Linear:
  23. def __init__(self):
  24. '''
  25. 定义线性回归模型的参数:a, b
  26. '''
  27. self.a = Scalar(0.0, label='a')
  28. self.b = Scalar(0.0, label='b')
  29. def forward(self, x):
  30. '''
  31. 根据当前的参数估计值,得到模型的预测结果
  32. '''
  33. return self.a * x + self.b
  34. def error(self, x, y):
  35. '''
  36. 当前数据的模型误差
  37. '''
  38. return y - self.forward(x)
  39. def string(self):
  40. '''
  41. 输出当前模型的结果
  42. '''
  43. return f'y = {self.a.value:.2f} * x + {self.b.value:.2f}'