|
|
@@ -47,24 +47,50 @@
|
|
|
"class LSTMCell(nn.Module):\n",
|
|
|
"\n",
|
|
|
" def __init__(self, input_size, hidden_size):\n",
|
|
|
+ " '''\n",
|
|
|
+ "        A long short-term memory (LSTM) cell\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        input_size :int, feature length of the input data\n",
|
|
|
+ "        hidden_size :int, feature length of the hidden state\n",
|
|
|
+ " '''\n",
|
|
|
" super().__init__()\n",
|
|
|
" self.input_size = input_size\n",
|
|
|
" self.hidden_size = hidden_size\n",
|
|
|
" combined_size = self.input_size + self.hidden_size\n",
|
|
|
+ "        # Define the linear part of the input gate\n",
|
|
|
" self.in_gate = nn.Linear(combined_size, self.hidden_size)\n",
|
|
|
+ "        # Define the linear part of the forget gate\n",
|
|
|
" self.forget_gate = nn.Linear(combined_size, self.hidden_size)\n",
|
|
|
+ "        # Define the linear part of the candidate cell state\n",
|
|
|
" self.new_cell_state = nn.Linear(combined_size, self.hidden_size)\n",
|
|
|
+ "        # Define the linear part of the output gate\n",
|
|
|
" self.out_gate = nn.Linear(combined_size, self.hidden_size)\n",
|
|
|
"\n",
|
|
|
" def forward(self, inputs, state=None):\n",
|
|
|
- " B, _ = inputs.shape # (B, I)\n",
|
|
|
- " # state: ((B, H), (B, H))\n",
|
|
|
+ " '''\n",
|
|
|
+ "        Forward pass\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        inputs :torch.FloatTensor\n",
|
|
|
+ "            Input data of shape (B, I), where B is the batch size and I is the input feature length (input_size)\n",
|
|
|
+ "        state :tuple(torch.FloatTensor, torch.FloatTensor)\n",
|
|
|
+ "            (hidden state, cell state); both have shape (B, H), where H is the hidden-state length (hidden_size)\n",
|
|
|
+ "        Returns\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        hs :torch.FloatTensor, hidden state of shape (B, H)\n",
|
|
|
+ "        cs :torch.FloatTensor, cell state of shape (B, H)\n",
|
|
|
+ " '''\n",
|
|
|
+ " B, _ = inputs.shape\n",
|
|
|
" if state is None:\n",
|
|
|
" state = self.init_state(B, inputs.device)\n",
|
|
|
" hs, cs = state\n",
|
|
|
" combined = torch.cat((inputs, hs), dim=1) # (B, I + H)\n",
|
|
|
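+ "        # Concatenating the input and hidden state lets each gate apply a single Linear layer to both\n",
|
|
|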
+ "        # Input gate\n",
|
|
|
" ingate = F.sigmoid(self.in_gate(combined)) # (B, H)\n",
|
|
|
+ "        # Forget gate\n",
|
|
|
" forgetgate = F.sigmoid(self.forget_gate(combined)) # (B, H)\n",
|
|
|
+ "        # Output gate\n",
|
|
|
" outgate = F.sigmoid(self.out_gate(combined)) # (B, H)\n",
|
|
|
"        # Update the cell state\n",
|
|
|
" ncs = F.tanh(self.new_cell_state(combined)) # (B, H)\n",
|
|
|
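+ "        # ncs is the candidate cell state of the standard LSTM update (c = f*c + i*ncs, h = o*tanh(c))\n",
|
|
|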
@@ -74,6 +100,7 @@
|
|
|
" return hs, cs\n",
|
|
|
"\n",
|
|
|
" def init_state(self, B, device):\n",
|
|
|
+ "        # The default hidden state and cell state are all zeros\n",
|
|
|
" cs = torch.zeros((B, self.hidden_size), device=device)\n",
|
|
|
" hs = torch.zeros((B, self.hidden_size), device=device)\n",
|
|
|
" return hs, cs\n",
|
|
|
@@ -81,21 +108,38 @@
|
|
|
"class LSTM(nn.Module):\n",
|
|
|
"\n",
|
|
|
" def __init__(self, input_size, hidden_size):\n",
|
|
|
+ " '''\n",
|
|
|
+ "        A single-layer LSTM (supports batched computation)\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        input_size :int, feature length of the input data\n",
|
|
|
+ "        hidden_size :int, feature length of the hidden state\n",
|
|
|
+ " '''\n",
|
|
|
" super().__init__()\n",
|
|
|
" self.input_size = input_size\n",
|
|
|
" self.hidden_size = hidden_size\n",
|
|
|
" self.lstm = LSTMCell(self.input_size, self.hidden_size)\n",
|
|
|
"\n",
|
|
|
" def forward(self, inputs, state=None):\n",
|
|
|
+ " '''\n",
|
|
|
+ "        Forward pass\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        inputs :torch.FloatTensor\n",
|
|
|
+ "            Batch of input sequences of shape (B, T, C), where B is the batch size, T is the sequence length and C is the input feature length (input_size)\n",
|
|
|
+ "        state :tuple(torch.FloatTensor, torch.FloatTensor)\n",
|
|
|
+ "            (initial hidden state, initial cell state); both have shape (B, H), where H is the hidden-state length (hidden_size)\n",
|
|
|
+ "        Returns\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        hidden :torch.FloatTensor, hidden states of all time steps, of shape (B, T, H)\n",
|
|
|
+ " '''\n",
|
|
|
" re = []\n",
|
|
|
- " # B batch_size,\n",
|
|
|
- " # T sequence length,\n",
|
|
|
- " # C number of channels.\n",
|
|
|
" B, T, C = inputs.shape\n",
|
|
|
" inputs = inputs.transpose(0, 1) # (T, B, C)\n",
|
|
|
" for i in range(T):\n",
|
|
|
" state = self.lstm(inputs[i], state)\n",
|
|
|
- " re.append(state[0]) # state[0]: (B, H)\n",
|
|
|
+ "            # Record only the hidden state; state[0] has shape (B, H)\n",
|
|
|
+ " re.append(state[0])\n",
|
|
|
" result_tensor = torch.stack(re, dim=0) # (T, B, H)\n",
|
|
|
" return result_tensor.transpose(0, 1) # (B, T, H)"
|
|
|
]
|
|
|
@@ -124,6 +168,10 @@
|
|
|
],
|
|
|
"source": [
|
|
|
"def test_lstm():\n",
|
|
|
+ " '''\n",
|
|
|
+ "    Test the correctness of the LSTM implementation\n",
|
|
|
+ " '''\n",
|
|
|
+ "    # Randomly generate the model configuration\n",
|
|
|
" B, T, input_size, hidden_size, num_layers = torch.randint(1, 20, (5,)).tolist()\n",
|
|
|
" ref_model = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)\n",
|
|
|
"    # Randomly generate the input\n",
|
|
|
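+ "    # The reference nn.LSTM uses batch_first=True, so its inputs have shape (B, T, input_size)\n",
|
|
|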
@@ -150,10 +198,11 @@
|
|
|
" model.lstm.new_cell_state.bias = nn.Parameter(cb)\n",
|
|
|
" model.lstm.out_gate.weight = nn.Parameter(o)\n",
|
|
|
" model.lstm.out_gate.bias = nn.Parameter(ob)\n",
|
|
|
- "    # Verify the result\n",
|
|
|
+ "    # Compute the hidden states\n",
|
|
|
" re = model(re, (_hs[layer_index], _cs[layer_index]))\n",
|
|
|
" ref_re, _ = ref_model(inputs, (hs, cs))\n",
|
|
|
- " out = torch.all(re - ref_re < 1e-4)\n",
|
|
|
+ "    # Verify the result (whether the last layer's hidden states match)\n",
|
|
|
+ " out = torch.all(torch.abs(re - ref_re) < 1e-4)\n",
|
|
|
" return out, (B, T, input_size, hidden_size, num_layers)\n",
|
|
|
"\n",
|
|
|
"test_lstm()"
|
|
|
@@ -167,10 +216,12 @@
|
|
|
},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
+ "# Some hyperparameters\n",
|
|
|
"learning_rate = 1e-3\n",
|
|
|
"eval_iters = 10\n",
|
|
|
"batch_size=1000\n",
|
|
|
"sequence_len=64\n",
|
|
|
+ "# If a GPU is available, this script will use it for computation\n",
|
|
|
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
|
|
|
]
|
|
|
},
|
|
|
@@ -233,25 +284,49 @@
|
|
|
"class CharLSTM(nn.Module):\n",
|
|
|
"\n",
|
|
|
" def __init__(self, vs):\n",
|
|
|
+ " '''\n",
|
|
|
+ "        A three-layer LSTM network\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        vs :int, vocabulary size\n",
|
|
|
+ " '''\n",
|
|
|
" super().__init__()\n",
|
|
|
+ "        # Define the feature length of the character embeddings\n",
|
|
|
" self.emb_size = 256\n",
|
|
|
+ "        # Define the feature length of the hidden state\n",
|
|
|
" self.hidden_size = 128\n",
|
|
|
+ "        # Character embedding layer\n",
|
|
|
" self.embedding = nn.Embedding(vs, self.emb_size)\n",
|
|
|
+ "        # Dropout\n",
|
|
|
" self.dp = nn.Dropout(0.4)\n",
|
|
|
+ "        # First LSTM layer\n",
|
|
|
" self.lstm1 = LSTM(self.emb_size, self.hidden_size)\n",
|
|
|
+ "        # Layer normalization\n",
|
|
|
" self.norm1 = nn.LayerNorm(self.hidden_size)\n",
|
|
|
" self.lstm2 = LSTM(self.hidden_size, self.hidden_size)\n",
|
|
|
" self.norm2 = nn.LayerNorm(self.hidden_size)\n",
|
|
|
" self.lstm3 = LSTM(self.hidden_size, self.hidden_size)\n",
|
|
|
" self.norm3 = nn.LayerNorm(self.hidden_size)\n",
|
|
|
+ "        # Language-modeling head: predicts the next character from the last layer's hidden states\n",
|
|
|
" self.h2o = nn.Linear(self.hidden_size, vs)\n",
|
|
|
"\n",
|
|
|
" def forward(self, x):\n",
|
|
|
- " # x: (B, T)\n",
|
|
|
+ " '''\n",
|
|
|
+ "        Forward pass\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        x :torch.LongTensor, indices of the current characters in the vocabulary, of shape (B, T)\n",
|
|
|
+ "        Returns\n",
|
|
|
+ "        ----\n",
+ "        output :torch.FloatTensor, logits of the predictions, of shape (B, T, vs)\n",
|
|
|
+ " '''\n",
|
|
|
" emb = self.embedding(x) # (B, T, C)\n",
|
|
|
" h = self.norm1(self.dp(self.lstm1(emb))) # (B, T, H)\n",
|
|
|
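+ "        # Dropout is applied to the LSTM outputs before layer normalization\n",
|
|
|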
+ "        # The hidden states of the first layer are the input to the second layer\n",
|
|
|
" h = self.norm2(self.dp(self.lstm2(h))) # (B, T, H)\n",
|
|
|
+ "        # The hidden states of the second layer are the input to the third layer\n",
|
|
|
" h = self.norm3(self.dp(self.lstm3(h))) # (B, T, H)\n",
|
|
|
+ "        # Use the hidden states of the third layer to predict the next character\n",
|
|
|
" output = self.h2o(h) # (B, T, vs)\n",
|
|
|
" return output\n",
|
|
|
"\n",
|
|
|
@@ -312,6 +387,7 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
+ "# Display the model structure\n",
|
|
|
"model"
|
|
|
]
|
|
|
},
|
|
|
@@ -325,15 +401,30 @@
|
|
|
"source": [
|
|
|
"@torch.no_grad()\n",
|
|
|
"def generate_batch(model, idx, max_new_tokens=300):\n",
|
|
|
+ " '''\n",
|
|
|
+ "    Generate text with the model (by repeatedly using the model to predict)\n",
|
|
|
+ "    Parameters\n",
|
|
|
+ "    ----\n",
|
|
|
+ "    model :CharLSTM, the model used to generate text\n",
|
|
|
+ "    idx :torch.LongTensor, indices of the current characters in the vocabulary, of shape (1, T)\n",
|
|
|
+ "    max_new_tokens :int, maximum length of the generated text\n",
|
|
|
+ "    Returns\n",
|
|
|
+ "    ----\n",
|
|
|
+ "    out :list[int], the generated text\n",
|
|
|
+ " '''\n",
|
|
|
"    # Switch the model to evaluation mode\n",
|
|
|
" model.eval()\n",
|
|
|
" for _ in range(max_new_tokens):\n",
|
|
|
"        # Limit the context length so that it better matches the conditions during training\n",
|
|
|
"        # (It is also fine not to limit it)\n",
|
|
|
" context = idx[:, -sequence_len:]\n",
|
|
|
+ "        # During text generation the model is computationally inefficient, because much of the computation is repeated\n",
|
|
|
" logits = model(context)\n",
|
|
|
+ "        # Use only the last prediction\n",
|
|
|
" logits = logits[:, -1, :]\n",
|
|
|
" probs = F.softmax(logits, dim=-1)\n",
|
|
|
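+ "        # Convert the logits into a probability distribution over the vocabulary\n",
|
|
|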
+ "        # Obtain the final prediction (the next character) from the predicted probabilities\n",
|
|
|
+ "        # This step involves some randomness\n",
|
|
|
" ix = torch.multinomial(probs, num_samples=1)\n",
|
|
|
" idx = torch.cat((idx, ix), dim=1)\n",
|
|
|
" if ix.item() == 0:\n",
|
|
|
@@ -364,6 +455,7 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
+ "# Use the model to generate text\n",
|
|
|
"begin_text = torch.tensor(tok.encode('def'), device=device).unsqueeze(0)\n",
|
|
|
"print(''.join(tok.decode(generate_batch(model, begin_text))))"
|
|
|
]
|
|
|
@@ -392,17 +484,26 @@
|
|
|
],
|
|
|
"source": [
|
|
|
"def process(data, sequence_len=sequence_len):\n",
|
|
|
+ " '''\n",
|
|
|
+ "    Build training data from the text\n",
|
|
|
+ " '''\n",
|
|
|
+ "    # text is a list of strings\n",
|
|
|
" text = data['whole_func_string']\n",
|
|
|
" inputs, labels = [], []\n",
|
|
|
" for i in text:\n",
|
|
|
" enc = tok.encode(i)\n",
|
|
|
+ "        # 0 corresponds to the end of the text\n",
|
|
|
" enc += [0]\n",
|
|
|
+ "        # Convert the text into multiple training examples\n",
|
|
|
" for i in range(len(enc) - sequence_len):\n",
|
|
|
" inputs.append(enc[i: i + sequence_len])\n",
|
|
|
+ "            # The label is the next character, so we only need to shift by one position\n",
|
|
|
" labels.append(enc[i + 1: i + 1 + sequence_len])\n",
|
|
|
" return {'inputs': inputs, 'labels': labels}\n",
|
|
|
"\n",
|
|
|
+ "# Split the data into a training set and a test set\n",
|
|
|
"tokenized = datasets.train_test_split(test_size=0.1, seed=1024, shuffle=True)\n",
|
|
|
+ "# Convert the text into training data containing inputs and labels\n",
|
|
|
"tokenized = tokenized.map(process, batched=True, remove_columns=datasets.column_names)\n",
|
|
|
"tokenized.set_format(type='torch', device=device)\n",
|
|
|
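+ "# set_format keeps the tensors on the chosen device, so batches need no extra transfer\n",
|
|
|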
"\n",
|
|
|
@@ -492,6 +593,7 @@
|
|
|
" \"\"\"\n",
|
|
|
" loss = []\n",
|
|
|
" data_iter= iter(data_loader)\n",
|
|
|
+ "    # Use several randomly drawn batches to estimate model performance\n",
|
|
|
" for k in range(eval_iters):\n",
|
|
|
" data = next(data_iter, None)\n",
|
|
|
" if data is None:\n",
|
|
|
@@ -499,6 +601,8 @@
|
|
|
" data = next(data_iter, None)\n",
|
|
|
" inputs, labels = data['inputs'], data['labels']\n",
|
|
|
" logits = model(inputs)\n",
|
|
|
+ "        # By the definition of cross_entropy, the logits need to be transposed\n",
|
|
|
+ "        # See the official cross_entropy documentation for details\n",
|
|
|
" logits = logits.transpose(-2, -1)\n",
|
|
|
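+ "        # logits now has shape (B, vs, T): cross_entropy expects the class dimension right after the batch dimension\n",
|
|
|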
" loss.append(F.cross_entropy(logits, labels).item())\n",
|
|
|
" return torch.tensor(loss).mean().item()\n",
|
|
|
@@ -521,6 +625,8 @@
|
|
|
" inputs, labels = data['inputs'], data['labels']\n",
|
|
|
" optimizer.zero_grad()\n",
|
|
|
" logits = model(inputs)\n",
|
|
|
+ "        # By the definition of cross_entropy, the logits need to be transposed\n",
|
|
|
+ "        # See the official cross_entropy documentation for details\n",
|
|
|
" logits = logits.transpose(-2, -1)\n",
|
|
|
" loss = F.cross_entropy(logits, labels)\n",
|
|
|
" lossi.append(loss.item())\n",
|
|
|
@@ -624,6 +730,7 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
+ "# Use the model to generate text\n",
|
|
|
"begin_text = torch.tensor(tok.encode('def'), device=device).unsqueeze(0)\n",
|
|
|
"print(''.join(tok.decode(generate_batch(model, begin_text))))"
|
|
|
]
|
|
|
@@ -640,25 +747,51 @@
|
|
|
"class LSTMLayerNormCell(nn.Module):\n",
|
|
|
"\n",
|
|
|
" def __init__(self, input_size, hidden_size):\n",
|
|
|
+ " '''\n",
|
|
|
+ "        A long short-term memory (LSTM) cell with built-in layer normalization\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        input_size :int, feature length of the input data\n",
|
|
|
+ "        hidden_size :int, feature length of the hidden state\n",
|
|
|
+ " '''\n",
|
|
|
" super().__init__()\n",
|
|
|
" self.input_size = input_size\n",
|
|
|
" self.hidden_size = hidden_size\n",
|
|
|
" combined_size = self.input_size + self.hidden_size\n",
|
|
|
+ "        # Define the four linear modules together, which makes the code more concise and efficient\n",
|
|
|
" self.gates = nn.Linear(\n",
|
|
|
" combined_size, 4 * self.hidden_size, bias=False)\n",
|
|
|
+ "        # Layer normalization for the gates\n",
|
|
|
" self.ln_gates = nn.LayerNorm(4 * self.hidden_size)\n",
|
|
|
+ "        # Layer normalization for the cell state\n",
|
|
|
" self.ln_c = nn.LayerNorm(self.hidden_size)\n",
|
|
|
"\n",
|
|
|
" def forward(self, inputs, state=None):\n",
|
|
|
- " B, _ = inputs.shape # (B, I)\n",
|
|
|
- " # state: ((B, H), (B, H))\n",
|
|
|
+ " '''\n",
|
|
|
+ "        Forward pass\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        inputs :torch.FloatTensor\n",
|
|
|
+ "            Input data of shape (B, I), where B is the batch size and I is the input feature length (input_size)\n",
|
|
|
+ "        state :tuple(torch.FloatTensor, torch.FloatTensor)\n",
|
|
|
+ "            (hidden state, cell state); both have shape (B, H), where H is the hidden-state length (hidden_size)\n",
|
|
|
+ "        Returns\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        hs :torch.FloatTensor, hidden state of shape (B, H)\n",
|
|
|
+ "        cs :torch.FloatTensor, cell state of shape (B, H)\n",
|
|
|
+ " '''\n",
|
|
|
+ " B, _ = inputs.shape\n",
|
|
|
" if state is None:\n",
|
|
|
" state = self.init_state(B, inputs.device)\n",
|
|
|
" hs, cs = state\n",
|
|
|
" combined = torch.cat((inputs, hs), dim=1) # (B, I + H)\n",
|
|
|
+ "        # Split apart the four linear components\n",
|
|
|
" i, f, c, o = self.ln_gates(self.gates(combined)).chunk(4, 1)\n",
|
|
|
+ "        # Input gate\n",
|
|
|
" ingate = F.sigmoid(i) # (B, H)\n",
|
|
|
+ "        # Forget gate\n",
|
|
|
" forgetgate = F.sigmoid(f) # (B, H)\n",
|
|
|
+ "        # Output gate\n",
|
|
|
" outgate = F.sigmoid(o) # (B, H)\n",
|
|
|
"        # Update the cell state\n",
|
|
|
" ncs = F.tanh(c) # (B, H)\n",
|
|
|
@@ -675,21 +808,38 @@
|
|
|
"class LSTMLayerNorm(nn.Module):\n",
|
|
|
"\n",
|
|
|
" def __init__(self, input_size, hidden_size):\n",
|
|
|
+ " '''\n",
|
|
|
+ "        A single-layer LSTM (supports batched computation, with built-in layer normalization)\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        input_size :int, feature length of the input data\n",
|
|
|
+ "        hidden_size :int, feature length of the hidden state\n",
|
|
|
+ " '''\n",
|
|
|
" super().__init__()\n",
|
|
|
" self.input_size = input_size\n",
|
|
|
" self.hidden_size = hidden_size\n",
|
|
|
" self.lstm = LSTMLayerNormCell(self.input_size, self.hidden_size)\n",
|
|
|
"\n",
|
|
|
" def forward(self, inputs, state=None):\n",
|
|
|
+ " '''\n",
|
|
|
+ "        Forward pass\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        inputs :torch.FloatTensor\n",
|
|
|
+ "            Batch of input sequences of shape (B, T, C), where B is the batch size, T is the sequence length and C is the input feature length (input_size)\n",
|
|
|
+ "        state :tuple(torch.FloatTensor, torch.FloatTensor)\n",
|
|
|
+ "            (initial hidden state, initial cell state); both have shape (B, H), where H is the hidden-state length (hidden_size)\n",
|
|
|
+ "        Returns\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        hidden :torch.FloatTensor, hidden states of all time steps, of shape (B, T, H)\n",
|
|
|
+ " '''\n",
|
|
|
" re = []\n",
|
|
|
- " # B batch_size,\n",
|
|
|
- " # T sequence length,\n",
|
|
|
- " # C number of channels.\n",
|
|
|
" B, T, C = inputs.shape\n",
|
|
|
" inputs = inputs.transpose(0, 1) # (T, B, C)\n",
|
|
|
" for i in range(T):\n",
|
|
|
" state = self.lstm(inputs[i], state)\n",
|
|
|
- " re.append(state[0]) # state[0]: (B, H)\n",
|
|
|
+ "            # Record only the hidden state; state[0] has shape (B, H)\n",
|
|
|
+ " re.append(state[0])\n",
|
|
|
" result_tensor = torch.stack(re, dim=0) # (T, B, H)\n",
|
|
|
" return result_tensor.transpose(0, 1) # (B, T, H)"
|
|
|
]
|
|
|
@@ -705,6 +855,12 @@
|
|
|
"class CharLSTMLayerNorm(nn.Module):\n",
|
|
|
"\n",
|
|
|
" def __init__(self, vs):\n",
|
|
|
+ " '''\n",
|
|
|
+ "        A three-layer LSTM network (with built-in layer normalization)\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        vs :int, vocabulary size\n",
|
|
|
+ " '''\n",
|
|
|
" super().__init__()\n",
|
|
|
" self.emb_size = 256\n",
|
|
|
" self.hidden_size = 128\n",
|
|
|
@@ -716,7 +872,15 @@
|
|
|
" self.h2o = nn.Linear(self.hidden_size, vs)\n",
|
|
|
"\n",
|
|
|
" def forward(self, x):\n",
|
|
|
- " # x: (B, T)\n",
|
|
|
+ " '''\n",
|
|
|
+ "        Forward pass\n",
|
|
|
+ "        Parameters\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        x :torch.LongTensor, indices of the current characters in the vocabulary, of shape (B, T)\n",
|
|
|
+ "        Returns\n",
|
|
|
+ "        ----\n",
|
|
|
+ "        output :torch.FloatTensor, logits of the predictions, of shape (B, T, vs)\n",
|
|
|
+ " '''\n",
|
|
|
" emb = self.embedding(x) # (B, T, C)\n",
|
|
|
" h = self.dp(self.lstm1(emb)) # (B, T, H)\n",
|
|
|
" h = self.dp(self.lstm2(h)) # (B, T, H)\n",
|
|
|
@@ -824,6 +988,7 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
+ "# Use the model to generate text\n",
|
|
|
"begin_text = torch.tensor(tok.encode('def '), device=device).unsqueeze(0)\n",
|
|
|
"print(''.join(tok.decode(generate_batch(model_norm, begin_text))))"
|
|
|
]
|