{ "cells": [
  { "cell_type": "code", "execution_count": null,
    "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8D5tOVWvG6Ss", "outputId": "ec1c6e86-e76d-4b83-d0a6-d84c64be2d4a" },
    "outputs": [],
    "source": [
      "import torch\n",
      "import torch.nn as nn\n",
      "import torch.nn.functional as F\n",
      "import torch.optim as optim\n",
      "from torch.utils.data import DataLoader\n",
      "from datasets import load_dataset\n",
      "import matplotlib.pyplot as plt\n",
      "%matplotlib inline\n",
      "\n",
      "\n",
      "# Seed PyTorch's global RNG (weight init, dropout masks, sampling) so runs are repeatable.\n",
      "# The trailing ';' suppresses the repr of the torch.Generator that manual_seed returns.\n",
      "torch.manual_seed(12046);"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GwjXDvhCG6Su", "outputId": "34409fcb-65b3-48ea-c6bf-37d3345be90f" },
    "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
      "tensor([[[ 1.0185, -1.3091, 1.2908, 0.5276],\n",
      " [-0.2985, 1.6259, 2.0433, -0.6417],\n",
      " [ 0.8795, -1.0512, 1.1491, 0.6116],\n",
      " [ 0.2128, -0.5512, 0.0450, 0.5010]]])\n",
      "tensor([[[ 1.0185, -inf, -inf, -inf],\n",
      " [-0.2985, 1.6259, -inf, -inf],\n",
      " [ 0.8795, -1.0512, 1.1491, -inf],\n",
      " [ 0.2128, -0.5512, 0.0450, 0.5010]]])\n",
      "tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n",
      " [0.1274, 0.8726, 0.0000, 0.0000],\n",
      " [0.4074, 0.0591, 0.5335, 0.0000],\n",
      " [0.2743, 0.1278, 0.2319, 0.3659]]])\n"
    ] } ],
    "source": [
      "# Show what the causal mask does: entries above the diagonal are set to -inf,\n",
      "# so softmax gives them zero weight and each row only uses itself and earlier positions.\n",
      "T = 4\n",
      "scores = torch.randn(1, T, T)\n",
      "print(scores)\n",
      "tril = torch.tril(torch.ones(T, T))\n",
      "scores = scores.masked_fill(tril == 0, float('-inf'))\n",
      "print(scores)\n",
      "print(F.softmax(scores, dim=-1))"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2o_FK0sUWW-t", "outputId": "ecba1e99-1255-4387-ce6c-9394c25606fe" },
    "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
      "tensor(1.0026) tensor(1.0010) tensor(4.0152)\n",
      "tensor(1.0026) tensor(1.0010) tensor(1.0038)\n",
      "tensor([[0.0921, 0.1476, 0.1698, 0.4256, 0.0489, 0.0599, 0.0172, 0.0388]])\n",
      "tensor([[0., 0., 0., 1., 0., 0., 0., 0.]])\n"
    ] } ],
    "source": [
      "# Show how dot-product scores inflate variance: q @ k^T sums head_size products of\n",
      "# unit-variance terms, so its std is about sqrt(head_size) (~4 here); dividing by\n",
      "# sqrt(head_size) brings it back to about 1.\n",
      "B, T, head_size = 32, 100, 16\n",
      "\n",
      "k = torch.randn(B, T, head_size)  # (B, T, H)\n",
      "q = torch.randn(B, T, head_size)  # (B, T, H)\n",
      "scores = q @ k.transpose(-2, -1)  # (B, T, T)\n",
      "print(k.std(), q.std(), scores.std())\n",
      "scores = scores / head_size ** 0.5\n",
      "print(k.std(), q.std(), scores.std())\n",
      "\n",
      "# Softmax on high-variance inputs saturates: nearly all probability mass lands on one position.\n",
      "x = torch.randn(1, 8)\n",
      "print(torch.softmax(x, dim=-1))\n",
      "print(torch.softmax(1000 * x, dim=-1))"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "id": "ujPVSgIgG6Sv" },
    "outputs": [],
    "source": [
      "def attention(query, key, value, dropout, mask=None):\n",
      "    \"\"\"Scaled dot-product attention.\n",
      "\n",
      "    Args:\n",
      "        query, key, value: tensors shaped (..., T, C); in this notebook (B, T, head_size).\n",
      "        dropout: callable applied to the attention weights (e.g. nn.Dropout).\n",
      "        mask: optional tensor broadcastable to (..., T, T); positions where mask == 0\n",
      "            cannot be attended to. With mask=None every token may use context on both\n",
      "            sides (bidirectional attention).\n",
      "\n",
      "    Returns:\n",
      "        (out, w_att): out is (..., T, C); w_att are the attention weights after dropout, (..., T, T).\n",
      "    \"\"\"\n",
      "    # Only the feature size is needed, so any leading dims (batch, heads, ...) are accepted.\n",
      "    C = query.size(-1)\n",
      "    # (..., T, C) @ (..., C, T) --> (..., T, T); dividing by sqrt(C) keeps the variance near 1.\n",
      "    scores = query @ key.transpose(-2, -1) / (C ** 0.5)\n",
      "    if mask is not None:\n",
      "        scores = scores.masked_fill(mask == 0, float('-inf'))\n",
      "    w_att = dropout(F.softmax(scores, dim=-1))  # (..., T, T)\n",
      "    out = w_att @ value  # (..., T, C)\n",
      "    return out, w_att"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "id": "SbJ2z5U3G6Sw" },
    "outputs": [],
    "source": [
      "class MaskedAttention(nn.Module):\n",
      "    \"\"\"A single causal (masked) self-attention head.\n",
      "\n",
      "    The causal mask is sized by the global `sequence_len` from the hyper-parameter cell,\n",
      "    which must have run before this class is instantiated.\n",
      "    \"\"\"\n",
      "\n",
      "    def __init__(self, emb_size, head_size):\n",
      "        super().__init__()\n",
      "        self.key = nn.Linear(emb_size, head_size, bias=False)\n",
      "        self.query = nn.Linear(emb_size, head_size, bias=False)\n",
      "        self.value = nn.Linear(emb_size, head_size, bias=False)\n",
      "        # Lower-triangular causal mask; a buffer, so it is not a trainable parameter.\n",
      "        self.register_buffer(\n",
      "            'tril', torch.tril(torch.ones(sequence_len, sequence_len)))\n",
      "        self.dropout = nn.Dropout(0.4)\n",
      "\n",
      "    def forward(self, x):\n",
      "        B, T, C = x.shape  # C = emb_size\n",
      "        q = self.query(x)  # (B, T, H)\n",
      "        k = self.key(x)    # (B, T, H)\n",
      "        v = self.value(x)  # (B, T, H)\n",
      "        mask = self.tril[:T, :T]\n",
      "        out, _ = attention(q, k, v, self.dropout, mask)\n",
      "        return out  # (B, T, H)\n",
      "\n",
      "class MaskedMultiHeadAttention(nn.Module):\n",
      "    \"\"\"emb_size // head_size causal heads whose outputs are concatenated and projected.\"\"\"\n",
      "\n",
      "    def __init__(self, emb_size, head_size):\n",
      "        super().__init__()\n",
      "        assert emb_size % head_size == 0, 'emb_size must be divisible by head_size'\n",
      "        n_head = emb_size // head_size\n",
      "        heads = [MaskedAttention(emb_size, head_size) for _ in range(n_head)]\n",
      "        self.heads = nn.ModuleList(heads)\n",
      "        self.proj = nn.Linear(emb_size, emb_size)\n",
      "        self.dropout = nn.Dropout(0.4)\n",
      "\n",
      "    def forward(self, x):\n",
      "        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, emb_size)\n",
      "        out = self.dropout(self.proj(out))\n",
      "        return out\n",
      "\n",
      "class FeedForward(nn.Module):\n",
      "    \"\"\"Position-wise MLP: emb_size -> 4 * emb_size -> emb_size with GELU and dropout.\"\"\"\n",
      "\n",
      "    def __init__(self, emb_size):\n",
      "        super().__init__()\n",
      "        self.l1 = nn.Linear(emb_size, 4 * emb_size)\n",
      "        self.l2 = nn.Linear(4 * emb_size, emb_size)\n",
      "        self.dropout = nn.Dropout(0.4)\n",
      "\n",
      "    def forward(self, x):\n",
      "        x = F.gelu(self.l1(x))\n",
      "        out = self.dropout(self.l2(x))\n",
      "        return out\n",
      "\n",
      "class Block(nn.Module):\n",
      "    \"\"\"Transformer block: masked multi-head attention then feed-forward, each in a residual branch.\"\"\"\n",
      "\n",
      "    def __init__(self, emb_size, head_size):\n",
      "        super().__init__()\n",
      "        self.mha = MaskedMultiHeadAttention(emb_size, head_size)\n",
      "        self.ff = FeedForward(emb_size)\n",
      "        self.ln1 = nn.LayerNorm(emb_size)\n",
      "        self.ln2 = nn.LayerNorm(emb_size)\n",
      "\n",
      "    def forward(self, x):\n",
      "        # Residual connections; LayerNorm is applied to the input of each sub-layer (pre-LN).\n",
      "        x = x + self.mha(self.ln1(x))\n",
      "        x = x + self.ff(self.ln2(x))\n",
      "        return x"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "id": "KlDmjSCMG6Sx" },
    "outputs": [],
    "source": [
      "emb_size = 128\n",
      "head_size = 8\n",
      "n_layer = 12\n",
      "# Maximum context length: sizes the position-embedding table and the causal masks.\n",
      "sequence_len = 64\n",
      "learning_rate = 1e-3\n",
      "# Number of batches averaged per split when estimating the loss.\n",
      "eval_iters = 20\n",
      "batch_size = 500\n",
      "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "id": "fwr_fSVpG6Sx" },
    "outputs": [],
    "source": [
      "class CharGPT(nn.Module):\n",
      "    \"\"\"Character-level GPT: token + learned position embeddings, n_layer Blocks,\n",
      "    a final LayerNorm and a linear head producing next-character logits.\n",
      "\n",
      "    Args:\n",
      "        vs: vocabulary size.\n",
      "    \"\"\"\n",
      "\n",
      "    def __init__(self, vs):\n",
      "        super().__init__()\n",
      "        self.token_embedding = nn.Embedding(vs, emb_size)\n",
      "        self.position_embedding = nn.Embedding(sequence_len, emb_size)\n",
      "        blocks = [Block(emb_size, head_size) for _ in range(n_layer)]\n",
      "        self.blocks = nn.Sequential(*blocks)\n",
      "        self.ln = nn.LayerNorm(emb_size)\n",
      "        self.lm_head = nn.Linear(emb_size, vs)\n",
      "\n",
      "    def forward(self, x):\n",
      "        \"\"\"Map (B, T) token ids, T <= sequence_len, to (B, T, vs) logits.\"\"\"\n",
      "        B, T = x.shape\n",
      "        if T > sequence_len:\n",
      "            # Position embeddings and causal masks only cover sequence_len positions.\n",
      "            raise ValueError(f'sequence length {T} exceeds sequence_len={sequence_len}')\n",
      "        pos = torch.arange(0, T, dtype=torch.long, device=x.device)\n",
      "        tok_emb = self.token_embedding(x)       # (B, T, C)\n",
      "        pos_emb = self.position_embedding(pos)  # (   T, C)\n",
      "        x = tok_emb + pos_emb                   # (B, T, C)\n",
      "        x = self.blocks(x)                      # (B, T, C)\n",
      "        x = self.ln(x)                          # (B, T, C)\n",
      "        logits = self.lm_head(x)                # (B, T, vs)\n",
      "        return logits"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8LkvZyGjG6Sx", "outputId": "022464c0-2193-49f8-9030-1547721fa247" },
    "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "98" ] }, "metadata": {}, "execution_count": 8 } ],
    "source": [
      "raw_datasets = load_dataset('code_search_net', 'python')\n",
      "# Keep only functions whose repository name contains 'apache/spark'.\n",
      "datasets = raw_datasets['train'].filter(lambda x: 'apache/spark' in x['repository_name'])\n",
      "\n",
      "class char_tokenizer:\n",
      "    \"\"\"Character-level tokenizer whose vocabulary is every character seen in `data`.\"\"\"\n",
      "\n",
      "    def __init__(self, data):\n",
      "        # Every distinct character in the data becomes a vocabulary entry.\n",
      "        chars = sorted(list(set(''.join(data))))\n",
      "        # Id 0 is reserved for the end-of-text special token '<|e|>'.\n",
      "        self.char2ind = {s : i + 1 for i, s in enumerate(chars)}\n",
      "        self.char2ind['<|e|>'] = 0\n",
      "        self.ind2char = {i : s for s, i in self.char2ind.items()}\n",
      "\n",
      "    def encode(self, text):\n",
      "        \"\"\"String -> list of ids; raises KeyError for characters not in the vocabulary.\"\"\"\n",
      "        return [self.char2ind[c] for c in text]\n",
      "\n",
      "    def decode(self, enc):\n",
      "        \"\"\"An int id -> its character; an iterable of ids -> list of characters.\"\"\"\n",
      "        if isinstance(enc, int):\n",
      "            return self.ind2char[enc]\n",
      "        return [self.ind2char[i] for i in enc]\n",
      "\n",
      "tok = char_tokenizer(datasets['whole_func_string'])\n",
      "len(tok.char2ind)"
    ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri":
"https://localhost:8080/" }, "id": "No_AQspwG6Sx", "outputId": "b4760951-b485-4f37-b041-12d90c41ea16" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "2408290 parameters\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "CharGPT(\n", " (token_embedding): Embedding(98, 128)\n", " (position_embedding): Embedding(64, 128)\n", " (blocks): Sequential(\n", " (0): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (1): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), 
eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (2): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (3): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (4): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, 
bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (5): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (6): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " 
(dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (7): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (8): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): 
LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (9): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (10): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (11): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, 
out_features=8, bias=False)\n",
      " (query): Linear(in_features=128, out_features=8, bias=False)\n",
      " (value): Linear(in_features=128, out_features=8, bias=False)\n",
      " (dropout): Dropout(p=0.4, inplace=False)\n",
      " )\n",
      " )\n",
      " (proj): Linear(in_features=128, out_features=128, bias=True)\n",
      " (dropout): Dropout(p=0.4, inplace=False)\n",
      " )\n",
      " (ff): FeedForward(\n",
      " (l1): Linear(in_features=128, out_features=512, bias=True)\n",
      " (l2): Linear(in_features=512, out_features=128, bias=True)\n",
      " (dropout): Dropout(p=0.4, inplace=False)\n",
      " )\n",
      " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      " )\n",
      " )\n",
      " (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n",
      " (lm_head): Linear(in_features=128, out_features=98, bias=True)\n",
      ")"
    ] }, "metadata": {}, "execution_count": 9 } ],
    "source": [
      "model = CharGPT(len(tok.char2ind)).to(device)\n",
      "print(f'{sum(p.numel() for p in model.parameters())} parameters')\n",
      "model"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "id": "VOl4s229G6Sy" },
    "outputs": [],
    "source": [
      "@torch.no_grad()\n",
      "def generate_batch(model, idx, max_new_tokens=300):\n",
      "    \"\"\"Autoregressively sample up to max_new_tokens tokens after the prompt `idx`.\n",
      "\n",
      "    Args:\n",
      "        model: module mapping (B, T) token ids to (B, T, vocab) logits, e.g. CharGPT.\n",
      "        idx: (B, T0) tensor of prompt token ids on the model's device.\n",
      "        max_new_tokens: maximum number of tokens to sample.\n",
      "\n",
      "    Returns:\n",
      "        Prompt plus sampled ids of the FIRST sequence in the batch, as a list of ints.\n",
      "        Sampling stops early once every sequence has produced the end token (id 0).\n",
      "    \"\"\"\n",
      "    # Evaluation mode disables dropout; the caller's mode is restored afterwards.\n",
      "    was_training = model.training\n",
      "    model.eval()\n",
      "    # Which sequences have emitted the end token (ix.item() only works for B == 1).\n",
      "    finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)\n",
      "    try:\n",
      "        for _ in range(max_new_tokens):\n",
      "            # Crop the context to the last sequence_len tokens; the position embedding\n",
      "            # and causal masks do not cover longer inputs.\n",
      "            context = idx[:, -sequence_len:]\n",
      "            logits = model(context)\n",
      "            logits = logits[:, -1, :]  # only the prediction for the next token is needed\n",
      "            probs = F.softmax(logits, dim=-1)\n",
      "            ix = torch.multinomial(probs, num_samples=1)  # (B, 1)\n",
      "            idx = torch.cat((idx, ix), dim=1)\n",
      "            finished |= ix.squeeze(-1) == 0\n",
      "            if finished.all():\n",
      "                break\n",
      "    finally:\n",
      "        model.train(was_training)\n",
      "    return idx.tolist()[0]"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HQizE-2mG6Sz", "outputId": "eea46f55-4fdf-49a4-d770-ea8246ca90d1" },
    "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
      "def* O(h/of(\"YP`soE f|dwöR:1'_v?Q9)Nsx/Q=CKf\\M:iKcaI%+Q3m\n"
    ] } ],
    "source": [
      "begin_text = torch.tensor(tok.encode('def'), device=device).unsqueeze(0)\n",
      "print(''.join(tok.decode(generate_batch(model, begin_text))))"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "q0Fw0TCyG6Sz", "outputId": "cd446838-d7c0-4c2f-8d2a-be6e19cb8651" },
    "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(torch.Size([605913, 64]), torch.Size([605913, 64]))" ] }, "metadata": {}, "execution_count": 12 } ],
    "source": [
      "def process(data, sequence_len=sequence_len):\n",
      "    \"\"\"Turn a batch of functions into fixed-length next-character training pairs.\n",
      "\n",
      "    Each function is encoded and terminated with the end token (id 0). A window of\n",
      "    sequence_len tokens slides over it one position at a time; labels are the inputs\n",
      "    shifted left by one. Encodings (end token included) of at most sequence_len\n",
      "    tokens yield no samples.\n",
      "    \"\"\"\n",
      "    inputs, labels = [], []\n",
      "    for text in data['whole_func_string']:\n",
      "        enc = tok.encode(text) + [0]\n",
      "        for start in range(len(enc) - sequence_len):\n",
      "            inputs.append(enc[start: start + sequence_len])\n",
      "            labels.append(enc[start + 1: start + 1 + sequence_len])\n",
      "    return {'inputs': inputs, 'labels': labels}\n",
      "\n",
      "tokenized = datasets.train_test_split(test_size=0.1, seed=1024, shuffle=True)\n",
      "tokenized = tokenized.map(process, batched=True, remove_columns=datasets.column_names)\n",
      "tokenized.set_format(type='torch', device=device)\n",
      "\n",
      "tokenized['train']['inputs'].shape, tokenized['train']['labels'].shape"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-x4i2q1iG6S0", "outputId": "f8f9d666-9f20-41e1-8295-b653acda6081" },
    "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [
      "{'inputs': tensor([[ 2, 2, 2, ..., 2, 2, 4],\n",
      " [81, 80, 88, ..., 2, 2, 10],\n",
      " [ 4, 37, 84, ..., 2, 2, 2],\n",
      " ...,\n",
      " [75, 85, 2, ..., 70, 71, 84],\n",
      " [ 2, 2, 2, ..., 67, 78, 53],\n",
      " [87, 84, 67, ..., 89, 2, 38]], device='cuda:0'),\n",
      " 'labels': tensor([[ 2, 2, 32, ..., 2, 4, 4],\n",
      " [80, 88, 71, ..., 2, 10, 70],\n",
      " [37, 84, 71, ..., 2, 2, 2],\n",
      " ...,\n",
      " [85, 2, 72, ..., 71, 84, 75],\n",
      " [ 2, 2, 2, ..., 78, 53, 81],\n",
      " [84, 67, 86, ..., 2, 38, 53]], device='cuda:0')}"
    ] }, "metadata": {}, "execution_count": 13 } ],
    "source": [
      "# Build the data loaders. Shuffling the test loader too means the eval_iters batches\n",
      "# drawn during loss estimation are a random sample of the split.\n",
      "train_loader = DataLoader(tokenized['train'], batch_size=batch_size, shuffle=True)\n",
      "test_loader = DataLoader(tokenized['test'], batch_size=batch_size, shuffle=True)\n",
      "# Fetch one batch to inspect it.\n",
      "next(iter(test_loader))"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QD37LTDbG6S0", "outputId": "df95d847-e355-4e66-b22b-8f9386a71f42" },
    "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'train': 4.730087757110596, 'test': 4.726046562194824}" ] }, "metadata": {}, "execution_count": 14 } ],
    "source": [
      "def estimate_loss(model):\n",
      "    \"\"\"Return {'train': loss, 'test': loss}: mean cross-entropy over eval_iters batches per loader.\"\"\"\n",
      "    stats = {}\n",
      "    # Evaluation mode disables dropout; the caller's mode is restored even on error.\n",
      "    was_training = model.training\n",
      "    model.eval()\n",
      "    try:\n",
      "        stats['train'] = _loss(model, train_loader)\n",
      "        stats['test'] = _loss(model, test_loader)\n",
      "    finally:\n",
      "        model.train(was_training)\n",
      "    return stats\n",
      "\n",
      "@torch.no_grad()\n",
      "def _loss(model, data_loader):\n",
      "    \"\"\"\n",
      "    Mean cross-entropy of `model` over eval_iters batches from `data_loader`,\n",
      "    restarting the loader whenever it runs out of batches.\n",
      "    \"\"\"\n",
      "    loss = []\n",
      "    data_iter = iter(data_loader)\n",
      "    for _ in range(eval_iters):\n",
      "        data = next(data_iter, None)\n",
      "        if data is None:\n",
      "            data_iter = iter(data_loader)\n",
      "            data = next(data_iter, None)\n",
      "            if data is None:\n",
      "                raise ValueError('data_loader yields no batches')\n",
      "        inputs, labels = data['inputs'], data['labels']\n",
      "        logits = model(inputs)\n",
      "        # cross_entropy expects the class dimension second: (B, vs, T).\n",
      "        logits = logits.transpose(-2, -1)\n",
      "        loss.append(F.cross_entropy(logits, labels).item())\n",
      "    return torch.tensor(loss).mean().item()\n",
      "\n",
      "estimate_loss(model)"
    ] },
  { "cell_type": "code", "execution_count": null,
    "metadata": { "id": "TgKhC5TmG6S0" },
    "outputs": [],
    "source": [
      "def train_gpt(model, optimizer, data_loader, epochs=10):\n",
      "    \"\"\"Train `model` on next-character cross-entropy for `epochs` passes over data_loader.\n",
      "\n",
      "    Prints train/test loss (from estimate_loss) after each epoch and returns the list\n",
      "    of per-batch training losses.\n",
      "    \"\"\"\n",
      "    lossi = []\n",
      "    for epoch in range(epochs):\n",
      "        for data in data_loader:\n",
      "            inputs, labels = data['inputs'], data['labels']\n",
      "            optimizer.zero_grad()\n",
      "            logits = model(inputs)\n",
      "            # cross_entropy expects the class dimension second: (B, vs, T).\n",
      "            logits = logits.transpose(-2, -1)\n",
      "            loss = F.cross_entropy(logits, labels)\n",
      "            lossi.append(loss.item())\n",
      "            loss.backward()\n",
      "            optimizer.step()\n",
      "        # Evaluate the model and report the results.\n",
      "        stats = estimate_loss(model)\n",
      "        train_loss = f'train loss {stats[\"train\"]:.4f}'\n",
      "        test_loss = f'test loss {stats[\"test\"]:.4f}'\n",
      "        print(f'epoch {epoch:>2}: {train_loss}, {test_loss}')\n",
      "    return lossi"
    ] },
  { "cell_type": "code", "execution_count": 16,
    "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MCPIFH2dG6S1", "outputId": "b32049d3-4534-4c28-d86e-6cef0d4859f7" },
    "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
      "epoch 0: train loss 0.9037, test loss 1.1066\n",
      "epoch 1: train loss 0.7246, test loss 1.0086\n",
      "epoch 2: train loss 0.6448, test loss 0.9719\n",
      "epoch 3: train loss 0.5838, test loss 0.9607\n",
      "epoch 4: train loss 0.5468, test loss 0.9672\n",
      "epoch 5: train loss 0.5156, test loss 0.9663\n",
      "epoch 6: train loss 0.4891, test loss 0.9596\n",
      "epoch 7: train loss 0.4687, test loss 0.9652\n",
      "epoch 8: train loss 0.4517, test loss 0.9709\n",
      "epoch 9: train loss 0.4347, test loss 0.9761\n"
    ] } ],
    "source": [
      "l = train_gpt(model, optim.AdamW(model.parameters(), lr=learning_rate), train_loader)"
    ] },
  { "cell_type": "code", "execution_count": 17,
    "metadata": { "id": "pgRJxHwOG6S1", "colab": { "base_uri": "https://localhost:8080/", "height": 448 }, "outputId": "1ba354b1-57f0-4fc3-bf79-64baea7bf03e" },
    "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "image/png": "\n" }, "metadata": {} } ],
    "source": [
      "# Smooth the per-batch training loss by averaging consecutive windows of 10 batches.\n",
      "window = 10\n",
      "losses = torch.tensor(l)\n",
      "# view(-1, window) needs a length divisible by window, so drop the incomplete tail.\n",
      "losses = losses[: len(losses) // window * window]\n",
      "plt.plot(losses.view(-1, window).mean(1).numpy())\n",
      "plt.xlabel(f'batch / {window}')\n",
      "plt.ylabel('training loss');"
    ] },
  { "cell_type": "code", "execution_count": 18,
    "metadata": { "id": "cPqZDUJ8I2AA", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "b602a445-1cd5-425e-dac1-6b2750ae0684" },
    "outputs": [ { "output_type": "stream", "name": "stdout", "text": [
      "def _initialModel._to_java_impl():\n",
      " \"\"\"\n",
      " Deprecated in 2.3.0. Use :func:`pyspark.sql.types.DataType`, int or :class:`Column` expression in the given key (default param).\n",
      "\n",
      " >>> df = spark.range(1, 0).alias('age')).collect()\n",
      " [Row(name=u'Alice', age=1, name=u'Alice')]\n",
      " \"\n"
    ] } ],
    "source": [
      "begin_text = torch.tensor(tok.encode('def '), device=device).unsqueeze(0)\n",
      "print(''.join(tok.decode(generate_batch(model, begin_text))))"
    ] }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": { "gpuType": "V100", "provenance": [] },
  "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" },
  "language_info": {
   "codemirror_mode": { "name": "ipython", "version": 3 },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}