{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from utils import Scalar, draw_graph\n",
"from linear_model import Linear, mse"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Training data: two (x, y) points, each created with requires_grad=False.\n",
"x1, y1 = (\n",
"    Scalar(1.5, label='x1', requires_grad=False),\n",
"    Scalar(1.0, label='y1', requires_grad=False),\n",
")\n",
"x2, y2 = (\n",
"    Scalar(2.0, label='x2', requires_grad=False),\n",
"    Scalar(4.0, label='y2', requires_grad=False),\n",
")\n",
"# Backpropagation: one backward pass over the loss built from both points.\n",
"model = Linear()\n",
"residuals = [model.error(x1, y1), model.error(x2, y2)]\n",
"loss = mse(residuals)\n",
"loss.backward()\n",
"# Keep this call last: its return value is what the cell displays.\n",
"draw_graph(loss, 'backward')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Gradient accumulation, pass 1 of 2: start from a fresh model.\n",
"model = Linear()\n",
"# Backpropagate the (x1, y1) point on its own.\n",
"# The loss is scaled by 0.5 because the gradient is accumulated over 2 passes.\n",
"first_error = model.error(x1, y1)\n",
"loss = 0.5 * mse([first_error])\n",
"loss.backward()\n",
"draw_graph(loss, 'backward')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Gradient accumulation, pass 2 of 2: backpropagate the (x2, y2) point.\n",
"# The model is deliberately NOT re-created here, so this pass runs on the\n",
"# same parameters as the previous cell.\n",
"second_error = model.error(x2, y2)\n",
"loss = 0.5 * mse([second_error])\n",
"loss.backward()\n",
"draw_graph(loss, 'backward')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Fix the random seed so that the results below can be reproduced.\n",
"torch.manual_seed(1024)\n",
"# Raw inputs for the training data: 200 evenly spaced values from 100 to 300.\n",
"x_origin = torch.linspace(100, 300, 200)\n",
"# Standardize x; without this, gradient descent easily becomes unstable.\n",
"x_mean, x_std = torch.mean(x_origin), torch.std(x_origin)\n",
"x = (x_origin - x_mean) / x_std\n",
"# Targets: the line 10 * x + 5 plus standard-normal noise.\n",
"epsilon = torch.randn(x.shape)\n",
"y = 10 * x + 5 + epsilon"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step 4, Result: y = 3.12 * x + -1.99\n",
"Step 8, Result: y = 3.48 * x + -2.28\n",
"Step 12, Result: y = 3.22 * x + -1.97\n",
"Step 16, Result: y = 2.85 * x + -1.22\n",
"Step 20, Result: y = 2.68 * x + -0.23\n",
"Step 24, Result: y = 2.92 * x + 1.08\n",
"Step 28, Result: y = 3.74 * x + 2.61\n",
"Step 32, Result: y = 5.07 * x + 4.15\n",
"Step 36, Result: y = 6.73 * x + 5.52\n",
"Step 40, Result: y = 8.22 * x + 6.48\n",
"Step 44, Result: y = 9.36 * x + 5.75\n",
"Step 48, Result: y = 9.75 * x + 5.42\n",
"Step 52, Result: y = 9.88 * x + 5.28\n",
"Step 56, Result: y = 9.89 * x + 5.26\n",
"Step 60, Result: y = 9.89 * x + 5.20\n",
"Step 64, Result: y = 9.88 * x + 5.18\n",
"Step 68, Result: y = 9.88 * x + 5.17\n",
"Step 72, Result: y = 9.84 * x + 5.14\n",
"Step 76, Result: y = 9.86 * x + 5.15\n",
"Step 80, Result: y = 9.94 * x + 5.21\n"
]
}
],
"source": [
"# Build the model.\n",
"model = Linear()\n",
"# Number of samples that go into one parameter update (one batch).\n",
"batch_size = 20\n",
"# Number of backward passes accumulated into each parameter update.\n",
"gradient_accumulation_iter = 4\n",
"# Number of samples per backward pass (floor division keeps it an int).\n",
"micro_size = batch_size // gradient_accumulation_iter\n",
"learning_rate = 0.1\n",
"# Number of parameter updates to run.\n",
"num_updates = 20\n",
"\n",
"for t in range(num_updates * gradient_accumulation_iter):\n",
"    # Select the current micro-batch, wrapping around at the end of the data.\n",
"    ix = (t * micro_size) % len(x)\n",
"    xx = x[ix: ix + micro_size]\n",
"    yy = y[ix: ix + micro_size]\n",
"    # Loss on the current micro-batch.\n",
"    loss = mse([model.error(_x, _y) for _x, _y in zip(xx, yy)])\n",
"    # Scale by the number of accumulated passes, so the accumulated gradient\n",
"    # is an average over the passes rather than a sum.\n",
"    loss *= 1 / gradient_accumulation_iter\n",
"    # Backpropagate; gradients are only reset after a parameter update below.\n",
"    loss.backward()\n",
"    if (t + 1) % gradient_accumulation_iter == 0:\n",
"        # Gradient-descent update of the parameter estimates.\n",
"        model.a -= learning_rate * model.a.grad\n",
"        model.b -= learning_rate * model.b.grad\n",
"        # Reset the gradients that have just been used.\n",
"        model.a.grad = 0.0\n",
"        model.b.grad = 0.0\n",
"        print(f'Step {t + 1}, Result: {model.string()}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}