{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8D5tOVWvG6Ss", "outputId": "ec1c6e86-e76d-4b83-d0a6-d84c64be2d4a" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 1 } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader\n", "from datasets import load_dataset\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "\n", "torch.manual_seed(12046)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GwjXDvhCG6Su", "outputId": "34409fcb-65b3-48ea-c6bf-37d3345be90f" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor([[[ 1.0185, -1.3091, 1.2908, 0.5276],\n", " [-0.2985, 1.6259, 2.0433, -0.6417],\n", " [ 0.8795, -1.0512, 1.1491, 0.6116],\n", " [ 0.2128, -0.5512, 0.0450, 0.5010]]])\n", "tensor([[[ 1.0185, -inf, -inf, -inf],\n", " [-0.2985, 1.6259, -inf, -inf],\n", " [ 0.8795, -1.0512, 1.1491, -inf],\n", " [ 0.2128, -0.5512, 0.0450, 0.5010]]])\n", "tensor([[[1.0000, 0.0000, 0.0000, 0.0000],\n", " [0.1274, 0.8726, 0.0000, 0.0000],\n", " [0.4074, 0.0591, 0.5335, 0.0000],\n", " [0.2743, 0.1278, 0.2319, 0.3659]]])\n" ] } ], "source": [ "# 展示mask的作用\n", "T = 4\n", "scores = torch.randn(1, T, T)\n", "print(scores)\n", "tril = torch.tril(torch.ones(T, T))\n", "scores = scores.masked_fill(tril == 0, float('-inf'))\n", "print(scores)\n", "print(F.softmax(scores, dim=-1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2o_FK0sUWW-t", "outputId": "ecba1e99-1255-4387-ce6c-9394c25606fe" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "tensor(1.0026) tensor(1.0010) tensor(4.0152)\n", "tensor(1.0026) tensor(1.0010) tensor(1.0038)\n", "tensor([[0.0921, 0.1476, 0.1698, 0.4256, 0.0489, 0.0599, 0.0172, 0.0388]])\n", "tensor([[0., 0., 0., 1., 0., 0., 0., 0.]])\n" ] } ], "source": [ "# 展示对齐分数的方差放大效应\n", "B, T, head_size = 32, 100, 16\n", "\n", "k = torch.randn(B, T, head_size) # (B, T, H)\n", "q = torch.randn(B, T, head_size) # (B, T, H)\n", "scores = q @ k.transpose(-2, -1) # (B, T, T)\n", "print(k.std(), q.std(), scores.std())\n", "scores = scores / head_size ** 0.5\n", "print(k.std(), q.std(), scores.std())\n", "\n", "# Softmax函数在处理方差较大的数据时,会发生聚集效应(结果过于集中在一个点上)\n", "x = torch.randn(1, 8)\n", "print(torch.softmax(x, dim=-1))\n", "print(torch.softmax(1000 * x, dim=-1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ujPVSgIgG6Sv" }, "outputs": [], "source": [ "def attention(query, key, value, dropout, mask=None):\n", " # query, key, value都有相同的形状\n", " B, T, C = query.shape\n", " # (B, T, C) @ (B, C, T) --> (B, T, T)\n", " scores = query @ key.transpose(-2, -1) / (C ** 0.5)\n", " if mask is not None:\n", " # 如果没有mask,则表示词元可以使用左右两边的背景,也就是双向注意力\n", " # mask的形状是(T, T)\n", " scores = scores.masked_fill(mask == 0, float('-inf'))\n", " w_att = dropout(F.softmax(scores, dim=-1)) # (B, T, T)\n", " out = w_att @ value # (B, T, C)\n", " return out, w_att" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SbJ2z5U3G6Sw" }, "outputs": [], "source": [ "class MaskedAttention(nn.Module):\n", "\n", " def __init__(self, emb_size, head_size):\n", " super().__init__()\n", " self.key = nn.Linear(emb_size, head_size, bias=False)\n", " self.query = nn.Linear(emb_size, head_size, bias=False)\n", " self.value = nn.Linear(emb_size, head_size, bias=False)\n", " # 这个上三角矩阵不参与模型训练\n", " self.register_buffer(\n", " 'tril', torch.tril(torch.ones(sequence_len, sequence_len)))\n", " self.dropout = nn.Dropout(0.4)\n", "\n", " def forward(self, x):\n", " B, T, C = x.shape # C = emb_size\n", " q = self.query(x) # (B, T, H)\n", " k = self.key(x) # (B, T, H)\n", " v = self.value(x) # (B, T, H)\n", " mask = self.tril[:T, :T]\n", " out, _ = attention(q, k, v, self.dropout, mask)\n", " return out # (B, T, H)\n", "\n", "class MaskedMultiHeadAttention(nn.Module):\n", "\n", " def __init__(self, emb_size, head_size):\n", " super().__init__()\n", " assert(emb_size % head_size == 0)\n", " n_head = emb_size // head_size\n", " heads = [MaskedAttention(emb_size, head_size) for _ in range(n_head)]\n", " self.heads = nn.ModuleList(heads)\n", " self.proj = nn.Linear(emb_size, emb_size)\n", " self.dropout = nn.Dropout(0.4)\n", "\n", " def forward(self, x):\n", " out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, emb_size)\n", " out = self.dropout(self.proj(out))\n", " return out\n", "\n", "class FeedForward(nn.Module):\n", "\n", " def __init__(self, emb_size):\n", " super().__init__()\n", " self.l1 = nn.Linear(emb_size, 4 * emb_size)\n", " self.l2 = nn.Linear(4 * emb_size, emb_size)\n", " self.dropout = nn.Dropout(0.4)\n", "\n", " def forward(self, x):\n", " x = F.gelu(self.l1(x))\n", " out = self.dropout(self.l2(x))\n", " return out\n", "\n", "class Block(nn.Module):\n", "\n", " def __init__(self, emb_size, head_size):\n", " super().__init__()\n", " self.mha = MaskedMultiHeadAttention(emb_size, head_size)\n", " self.ff = FeedForward(emb_size)\n", " self.ln1 = nn.LayerNorm(emb_size)\n", " self.ln2 = nn.LayerNorm(emb_size)\n", "\n", " def forward(self, x):\n", " # 残差连接\n", " x = x + self.mha(self.ln1(x))\n", " x = x + self.ff(self.ln2(x))\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KlDmjSCMG6Sx" }, "outputs": [], "source": [ "emb_size = 128\n", "head_size = 8\n", "n_layer = 12\n", "sequence_len = 64\n", "learning_rate = 1e-3\n", "eval_iters = 20\n", "batch_size=500\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fwr_fSVpG6Sx" }, "outputs": [], "source": [ "class CharGPT(nn.Module):\n", "\n", " def __init__(self, vs):\n", " super().__init__()\n", " self.token_embedding = nn.Embedding(vs, emb_size)\n", " self.position_embedding = nn.Embedding(sequence_len, emb_size)\n", " blocks = [Block(emb_size, head_size) for _ in range(n_layer)]\n", " self.blocks = nn.Sequential(*blocks)\n", " self.ln = nn.LayerNorm(emb_size)\n", " self.lm_head = nn.Linear(emb_size, vs)\n", "\n", " def forward(self, x):\n", " B, T = x.shape\n", " pos = torch.arange(0, T, dtype=torch.long, device=x.device)\n", " tok_emb = self.token_embedding(x) # (B, T, C)\n", " pos_emb = self.position_embedding(pos) # ( T, C)\n", " x = tok_emb + pos_emb # (B, T, C)\n", " x = self.blocks(x) # (B, T, C)\n", " x = self.ln(x) # (B, T, C)\n", " logits = self.lm_head(x) # (B, T, vs)\n", " return logits" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8LkvZyGjG6Sx", "outputId": "022464c0-2193-49f8-9030-1547721fa247" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "98" ] }, "metadata": {}, "execution_count": 8 } ], "source": [ "raw_datasets = load_dataset('code_search_net', 'python')\n", "datasets = raw_datasets['train'].filter(lambda x: 'apache/spark' in x['repository_name'])\n", "\n", "class char_tokenizer:\n", "\n", " def __init__(self, data):\n", " # 数据中出现的所有字符构成字典\n", " chars = sorted(list(set(''.join(data))))\n", " # 预留一个位置给结尾的特殊字符\n", " self.char2ind = {s : i + 1 for i, s in enumerate(chars)}\n", " self.char2ind['<|e|>'] = 0\n", " self.ind2char = {i : s for s, i in self.char2ind.items()}\n", "\n", " def encode(self, text):\n", " return [self.char2ind[c] for c in text]\n", "\n", " def decode(self, enc):\n", " if isinstance(enc, int):\n", " return self.ind2char[enc]\n", " return [self.ind2char[i] for i in enc]\n", "\n", "tok = char_tokenizer(datasets['whole_func_string'])\n", "len(tok.char2ind)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "No_AQspwG6Sx", "outputId": "b4760951-b485-4f37-b041-12d90c41ea16" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "2408290 parameters\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "CharGPT(\n", " (token_embedding): Embedding(98, 128)\n", " (position_embedding): Embedding(64, 128)\n", " (blocks): Sequential(\n", " (0): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (1): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (2): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (3): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (4): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (5): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (6): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (7): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (8): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (9): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (10): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (11): Block(\n", " (mha): MaskedMultiHeadAttention(\n", " (heads): ModuleList(\n", " (0-15): 16 x MaskedAttention(\n", " (key): Linear(in_features=128, out_features=8, bias=False)\n", " (query): Linear(in_features=128, out_features=8, bias=False)\n", " (value): Linear(in_features=128, out_features=8, bias=False)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " )\n", " (proj): Linear(in_features=128, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ff): FeedForward(\n", " (l1): Linear(in_features=128, out_features=512, bias=True)\n", " (l2): Linear(in_features=512, out_features=128, bias=True)\n", " (dropout): Dropout(p=0.4, inplace=False)\n", " )\n", " (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", " (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (lm_head): Linear(in_features=128, out_features=98, bias=True)\n", ")" ] }, "metadata": {}, "execution_count": 9 } ], "source": [ "model = CharGPT(len(tok.char2ind)).to(device)\n", "print(f'{sum(p.numel() for p in model.parameters())} parameters')\n", "model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VOl4s229G6Sy" }, "outputs": [], "source": [ "@torch.no_grad()\n", "def generate_batch(model, idx, max_new_tokens=300):\n", " # 将模型切换至评估模式\n", " model.eval()\n", " for _ in range(max_new_tokens):\n", " # 限制背景长度,否则会报错\n", " context = idx[:, -sequence_len:]\n", " logits = model(context)\n", " logits = logits[:, -1, :]\n", " probs = F.softmax(logits, dim=-1)\n", " ix = torch.multinomial(probs, num_samples=1)\n", " idx = torch.cat((idx, ix), dim=1)\n", " if ix.item() == 0:\n", " break\n", " # 将模型切换至训练模式\n", " model.train()\n", " return idx.tolist()[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HQizE-2mG6Sz", "outputId": "eea46f55-4fdf-49a4-d770-ea8246ca90d1" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "def* O(h/of(\"YP`soE f|dwöR:1'_v?Q9)Nsx/Q=CKf\\M:iKcaI%+Q3m\n" ] } ], "source": [ "begin_text = torch.tensor(tok.encode('def'), device=device).unsqueeze(0)\n", "print(''.join(tok.decode(generate_batch(model, begin_text))))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "q0Fw0TCyG6Sz", "outputId": "cd446838-d7c0-4c2f-8d2a-be6e19cb8651" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(torch.Size([605913, 64]), torch.Size([605913, 64]))" ] }, "metadata": {}, "execution_count": 12 } ], "source": [ "def process(data, sequence_len=sequence_len):\n", " text = data['whole_func_string']\n", " inputs, labels = [], []\n", " for i in text:\n", " enc = tok.encode(i)\n", " enc += [0]\n", " for i in range(len(enc) - sequence_len):\n", " inputs.append(enc[i: i + sequence_len])\n", " labels.append(enc[i + 1: i + 1 + sequence_len])\n", " return {'inputs': inputs, 'labels': labels}\n", "\n", "tokenized = datasets.train_test_split(test_size=0.1, seed=1024, shuffle=True)\n", "tokenized = tokenized.map(process, batched=True, remove_columns=datasets.column_names)\n", "tokenized.set_format(type='torch', device=device)\n", "\n", "tokenized['train']['inputs'].shape, tokenized['train']['labels'].shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-x4i2q1iG6S0", "outputId": "f8f9d666-9f20-41e1-8295-b653acda6081" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'inputs': tensor([[ 2, 2, 2, ..., 2, 2, 4],\n", " [81, 80, 88, ..., 2, 2, 10],\n", " [ 4, 37, 84, ..., 2, 2, 2],\n", " ...,\n", " [75, 85, 2, ..., 70, 71, 84],\n", " [ 2, 2, 2, ..., 67, 78, 53],\n", " [87, 84, 67, ..., 89, 2, 38]], device='cuda:0'),\n", " 'labels': tensor([[ 2, 2, 32, ..., 2, 4, 4],\n", " [80, 88, 71, ..., 2, 10, 70],\n", " [37, 84, 71, ..., 2, 2, 2],\n", " ...,\n", " [85, 2, 72, ..., 71, 84, 75],\n", " [ 2, 2, 2, ..., 78, 53, 81],\n", " [84, 67, 86, ..., 2, 38, 53]], device='cuda:0')}" ] }, "metadata": {}, "execution_count": 13 } ], "source": [ "# 构建数据读取器\n", "train_loader = DataLoader(tokenized['train'], batch_size=batch_size, shuffle=True)\n", "test_loader = DataLoader(tokenized['test'], batch_size=batch_size, shuffle=True)\n", "# 获取一个批量的数据\n", "next(iter(test_loader))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QD37LTDbG6S0", "outputId": "df95d847-e355-4e66-b22b-8f9386a71f42" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'train': 4.730087757110596, 'test': 4.726046562194824}" ] }, "metadata": {}, "execution_count": 14 } ], "source": [ "def estimate_loss(model):\n", " re = {}\n", " # 将模型切换至评估模式\n", " model.eval()\n", " re['train'] = _loss(model, train_loader)\n", " re['test'] = _loss(model, test_loader)\n", " # 将模型切换至训练模式\n", " model.train()\n", " return re\n", "\n", "@torch.no_grad()\n", "def _loss(model, data_loader):\n", " \"\"\"\n", " 计算模型在不同数据集下面的评估指标\n", " \"\"\"\n", " loss = []\n", " data_iter= iter(data_loader)\n", " for k in range(eval_iters):\n", " data = next(data_iter, None)\n", " if data is None:\n", " data_iter = iter(data_loader)\n", " data = next(data_iter, None)\n", " inputs, labels = data['inputs'], data['labels']\n", " logits = model(inputs)\n", " logits = logits.transpose(-2, -1)\n", " loss.append(F.cross_entropy(logits, labels).item())\n", " return torch.tensor(loss).mean().item()\n", "\n", "estimate_loss(model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TgKhC5TmG6S0" }, "outputs": [], "source": [ "def train_gpt(model, optimizer, data_loader, epochs=10):\n", " lossi = []\n", " for epoch in range(epochs):\n", " for i, data in enumerate(data_loader, 0):\n", " inputs, labels = data['inputs'], data['labels']\n", " optimizer.zero_grad()\n", " logits = model(inputs)\n", " logits = logits.transpose(-2, -1)\n", " loss = F.cross_entropy(logits, labels)\n", " lossi.append(loss.item())\n", " loss.backward()\n", " optimizer.step()\n", " # 评估模型,并输出结果\n", " stats = estimate_loss(model)\n", " train_loss = f'train loss {stats[\"train\"]:.4f}'\n", " test_loss = f'test loss {stats[\"test\"]:.4f}'\n", " print(f'epoch {epoch:>2}: {train_loss}, {test_loss}')\n", " return lossi" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MCPIFH2dG6S1", "outputId": "b32049d3-4534-4c28-d86e-6cef0d4859f7" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "epoch 0: train loss 0.9037, test loss 1.1066\n", "epoch 1: train loss 0.7246, test loss 1.0086\n", "epoch 2: train loss 0.6448, test loss 0.9719\n", "epoch 3: train loss 0.5838, test loss 0.9607\n", "epoch 4: train loss 0.5468, test loss 0.9672\n", "epoch 5: train loss 0.5156, test loss 0.9663\n", "epoch 6: train loss 0.4891, test loss 0.9596\n", "epoch 7: train loss 0.4687, test loss 0.9652\n", "epoch 8: train loss 0.4517, test loss 0.9709\n", "epoch 9: train loss 0.4347, test loss 0.9761\n" ] } ], "source": [ "l = train_gpt(model, optim.AdamW(model.parameters(), lr=learning_rate), train_loader)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "pgRJxHwOG6S1", "colab": { "base_uri": "https://localhost:8080/", "height": 448 }, "outputId": "1ba354b1-57f0-4fc3-bf79-64baea7bf03e" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA5R0lEQVR4nO3deXxU9b3/8fcsyUxCMpONZBKSsAiC7AgiAa9bI4hcK22vtdZbcO2lF+6V2kXRavvz/my819veLtfrclvlZy1irYItUhRBoJbIHiSgLLIkhEzCksxk3+b8/ggZCJCQQDInybyej8c8ZM4y8znfBzLvx/d8v99jMQzDEAAAgEmsZhcAAADCG2EEAACYijACAABMRRgBAACmIowAAABTEUYAAICpCCMAAMBUhBEAAGAqu9kFdEQgENCxY8cUGxsri8VidjkAAKADDMNQRUWF0tLSZLW23f/RK8LIsWPHlJGRYXYZAADgEhQWFio9Pb3N/b0ijMTGxkpqvhiXy2VyNQAAoCP8fr8yMjKCv+Nt6RVhpOXWjMvlIowAANDLXGyIBQNYAQCAqQgjAADAVIQRAABgKsIIAAAwFWEEAACYijACAABMRRgBAACmIowAAABTEUYAAICpCCMAAMBUhBEAAGAqwggAADBVr3hQXnf5zV8P6mhZjb4xOUMjPDyADwAAM4R1z8h7u4q1eONhFZysNrsUAADCVliHEevpRxoHDJMLAQAgjIV1GLGc/q9hkEYAADBLWIeRlp4RoggAAOYJ6zDS0jUSoGcEAADThHUYsZ4OI2QRAADME9ZhxKKWAaykEQAAzBLWYcQa1lcPAEDPENY/x2em9tIzAgCAWcI6jLQIBMyuAACA8BXWYYSpvQAAmC+sw4iFqb0AAJgurMNIS88IXSMAAJgnzMNI83/pGQEAwDxhHUZalmAligAAYJ6wDiP0jAAAYL6wDiNnBrCaWwcAAOEsrMPImQGspBEAAMxCGBE9IwAAmCmsw4iCHSOkEQAAzBLWYYSeEQAAzBfWYeR0xwhTewEAMFFYhxErt2kAADBdWIcRS/A2DWEEAACzhHkYaf4vWQQAAPOEdRhhACsAAOYL6zByZgAraQQAALOEdRhp6RnhNg0AAOYJ6zBiYTYNAACmC/MwwpgRAADMFtZhxBp8ai9pBAAAs4R1GGFqLwAA5gvrMHJmACtpBAAAs4R1GOHZNAAAmC+8wwjLwQMAYLqwDiOsMwIAgPnCOoxYgrNpzK0DAIBwFtZhxMqiZwAAmC6sw0jLmBGiCAAA5gnzMNL83wD3aQAAME2nwsgLL7ygsWPHyuVyyeVyKSsrS3/5y1/aPeett97SiBEj5HQ6NWbMGK1cufKyCu5KFtEzAgCA2ToVRtLT0/Xss89q27Zt2rp1q26++Wbdcccd2r179wWP37hxo+6++2498MAD2rFjh2bPnq3Zs2crPz+/S4q/XCwHDwCA+SzGZY7eTEhI0HPPPacHHnjgvH133XWXqqqqtGLFiuC2KVOmaPz48XrxxRc7/B1+v19ut1s+n08ul+tyym3lP9/fq//+6IDunTpIP/nyqC77XAAA0PHf70seM9LU1KSlS5eqqqpKWVlZFzwmNzdX2dnZrbbNmDFDubm57X52XV2d/H5/q1d3sDCbBgAA03U6jOzatUsxMTFyOByaN2+eli1bppEjR17wWK/Xq5SUlFbbUlJS5PV62/2OnJwcud3u4CsjI6OzZXbImRVYu+XjAQBAB3Q6jAwfPlx5eXnatGmTvvOd72ju3Lnas2dPlxa1aNEi+Xy+4KuwsLBLP7/FmWfTkEYAADCLvbMnREZGaujQoZKkiRMnasuWLfrlL3+pl1566bxjPR6PSkpKWm0rKSmRx+Np9zscDoccDkdnS+s0Kz0jAACY7rLXGQkEAqqrq7vgvqysLK1Zs6bVttWrV7c5xiTUzqzAam4dAACEs071jCxatEgzZ85UZmamKioqtGTJEq1bt07vv/++JGnOnDkaMGCAcnJyJEkPP/ywbrjhBv3sZz/TrFmztHTpUm3dulUvv/xy11/JJWAAKwAA5utUGCktLdWcOXNUXFwst9utsWPH6v3339ctt9wiSSooKJDVeqazZerUqVqyZIl+9KMf6fHHH9ewYcO0fPlyjR49umuv4hJZeGovAACm61QY+e1vf9vu/nXr1p237c4779Sdd97ZqaJCxcKiZwAAmC6sn03DAFYAAMwX5mGk+b9M7QUAwDxhHUaCD8ojiwAAYJrwDiPMpgEAwHRhHkYYMwIAgNnCOoycGTMCAADMEtZhpOXZNEztBQDAPGEdRqzWlgGshBEAAMwS1mGEFVgBADBfeIeR0/9tYgQrAACmCeswYre2zKYhjAAAYJbwDiO25stvpGcEAADThHcYOd0z0thEGAEAwCzhHUZsp8NIIGByJQAAhK/wDiOne0YYwAoAgHnCOozYrM2X38BtGgAATBPWYaTlNg09IwAAmCe8w8jp2zQNTYwZAQDALGEdRmyMGQEAwHRhHUYiTq8zQhgBAMA8YR1GWnpGGpjaCwCAacI6jASn9jKbBgAA04R5GGE5eAAAzBbeYSS4AithBAAAs4R3GAk+m4YxIwAAmCXMwwizaQAAMFtYhxGbrWU2DWEEAACzhHUYiWDRMwAATBfWYeTsFVgNg0ACAIAZwjqMtIwZkZhRAwCAWcI7jJweMyJxqwYAALOEdRhpuU0j8eReAADMEtZhJNJ25vLrGwkjAACYIazDiNVqCQaSOsIIAACmCOswIkmR9uYmoGcEAABzhH0YcdjpGQEAwExhH0boGQEAwFxhH0bO9Iw0mVwJAADhKezDCD0jAACYizDS0jPCOiMAAJgi7MOIw26TJNU1EEYAADBD2IeRlnVG6ukZAQDAFGEfRhwRp2/TNDCAFQAAM4R9GKFnBAAAc4V9GHFEMGYEAAAzhX0YoWcEAABzhX0YOTNmhDACAIAZwj6MnOkZYQArAABmCPswElwOnp4RAABMQRixM2YEAAAzdSqM5OTk6JprrlFsbKySk5M1e/Zs7d27t91zFi9eLIvF0urldDovq+iuxLNpAAAwV6fCyPr16zV//nx98sknWr16tRoaGjR9+nRVVVW1e57L5VJxcXHwdeTIkcsquisFl4MnjAAAYAp7Zw5etWpVq/eLFy9WcnKytm3bpuuvv77N8ywWizwez6VV2M3oGQEAwFyXNWbE5/NJkhISEto9rrKyUgMHDlRGRobuuOMO7d69u93j6+rq5Pf7W726S3AAayOzaQAAMMMlh5FAIKCFCxdq2rRpGj16dJvHDR8+XK+88oreffddvf766woEApo6daqOHj3a5jk5OTlyu93BV0ZGxqWWeVGRwTBCzwgAAGa45DAyf/585efna+nSpe0el5WVpTlz5mj8+PG64YYb9M4776h///566aWX2jxn0aJF8vl8wVdhYeGllnlRjBkBAMBcnRoz0mLBggVasWKFNmzYoPT09E6dGxERoQkTJujAgQNtHuNwOORwOC6ltE6jZwQAAHN1qmfEMAwtWLBAy5Yt09q1azV48OBOf2FTU5N27dql1NTUTp/bHaKCD8pjzAgAAGboVM/I/PnztWTJEr377ruKjY2V1+uVJLndbkVFRUmS5syZowEDBignJ0eS9PTTT2vKlCkaOnSoysvL9dxzz+nIkSN68MEHu/hSLk20ozmMVNY1mlwJAADhqVNh5IUXXpAk3Xjjja22v/rqq7r33nslSQUFBbJaz3S4lJWV6aGHHpLX61V8fLwmTpyojRs3auTIkZdXeReJcTQ3QXU9PSMAAJjBYhiGYXYRF+P3++V2u+Xz+eRyubr0s4+WVeu6f/9IDrtVe//vzC79bAAAwllHf7/D/tk0/SKbe0bqGgNq5Pk0AACEHGHEceZOVRW3agAACLmwDyORdqsibBZJUnU9g1gBAAi1sA8j0pnekSpm1AAAEHKEEUkuZ4QkyVdDGAEAINQII5LiolvCSL3JlQAAEH4II5LcUc1hpLy6weRKAAAIP4QRSXHRkZIIIwAAmIEwIimupWekhjACAECoEUZ01piRasaMAAAQaoQRnTVmhJ4RAABCjjAixowAAGAmwogYMwIAgJkII2LMCAAAZiKMiDEjAACYiTAiyR1cgbVBgYBhcjUAAIQXwojO9IwYhlRRy/NpAAAIJcKIJIfdpuhImySpnOfTAAAQUoSR0+J4Pg0AAKYgjJyWENO81kjuwZMmVwIAQHghjJz2pREpkqRPCCMAAIQUYeS0qwfGS5KKy2tNrgQAgPBCGDltQJxTknTMV2NyJQAAhBfCyGmp7ihJzVN7K2oZxAoAQKgQRk7r57DL5bRLkop93KoBACBUCCNnSYtr7h05Vs6tGgAAQoUwcpaWMELPCAAAoUMYOUuKq3kQa4mfMAIAQKgQRs6S4nJIIowAABBKhJGzeII9I3UmVwIAQPggjJyF2zQAAIQeYeQsydymAQAg5AgjZ2npGTlRWa+GpoDJ1QAAEB4II2dJiI5UhM0iSTpewbgRAABCgTByFqvVouTY5t4RL7dqAAAICcLIOVrGjZQSRgAACAnCyDlSYpneCwBAKBFGzsHCZwAAhBZh5BwJ/ZrDyKmqepMrAQAgPBBGzpEQEymJMAIAQKgQRs6REN0cRsqqCSMAAIQCYeQcCf2aw8hJekYAAAgJwsg5Ek/fpikjjAAAEBKEkXPEn75NU17ToKaAYXI1AAD0fYSRc8RHR0iSDEMqZ9wIAADdjjByDrvNKndUcyBhECsAAN2PMHIBiS2DWCsJIwAAdDfCyAW0DGI9QRgBAKDbEUYuICmmeRXWE5U8nwYAgO5GGLmAljByvIIwAgBAd+tUGMnJydE111yj2NhYJScna/bs2dq7d+9Fz3vrrbc0YsQIOZ1OjRkzRitXrrzkgkOhfyw9IwAAhEqnwsj69es1f/58ffLJJ1q9erUaGho0ffp0VVVVtXnOxo0bdffdd+uBBx7Qjh07NHv2bM2ePVv5+fmXXXx3ie/HkvAAAISKxTCMS17Z6/jx40pOTtb69et1/fXXX/CYu+66S1VVVVqxYkVw25QpUzR+/Hi9+OKLHfoev98vt9stn88nl8t1qeV22J92HtO/vrFDU4YkaOm3s7r9+wAA6Is6+vt9WWNGfD6fJCkhIaHNY3Jzc5Wdnd1q24wZM5Sbm9vmOXV1dfL7/a1eodSyzoivpjGk3wsAQDi65DASCAS0cOFCTZs2TaNHj27zOK/Xq5SUlFbbUlJS5PV62zwnJydHbrc7+MrIyLjUMi9JSxjx1zSE9HsBAAhHlxxG5s+fr/z8fC1durQr65EkLVq0SD6fL/gqLCzs8u9oz5meEcIIAADdzX4pJy1YsEArVqzQhg0blJ6e3u6xHo9HJSUlrbaVlJTI4/G0eY7D4ZDD4biU0rpESxiprGtUY1NAdhszoAEA6C6d+pU1DEMLFizQsmXLtHbtWg0ePPii52RlZWnNmjWttq1evVpZWT13YKjLeSaj+WsZNwIAQHfqVBiZP3++Xn/9dS1ZskSxsbHyer3yer2qqakJHjNnzhwtWrQo+P7hhx/WqlWr9LOf/Uyff/65fvKTn2jr1q1asGBB111FF7PbrIpxNAcSbtUAANC9OhVGXnjhBfl8Pt14441KTU0Nvt58883gMQUFBSouLg6+nzp1qpYsWaKXX35Z48aN0x//+EctX7683UGvPQHjRgAACI1OjRnpyJIk69atO2/bnXfeqTvvvLMzX2U6V1SEisprCCMAAHQzRma2wR3FbRoAAEKBMNKGhNNLwp/gYXkAAHQrwkgbPK4oSVKJv9bkSgAA6NsII21Ii3NKkop9hBEAALoTYaQNaXHNPSN7iv0dGrgLAAAuDWGkDdOGJslutehAaaWOltVc/AQAAHBJCCNtcEdFyONuvlVzvJJBrAAAdBfCSDsSY5qfj3Oyst7kSgAA6LsII+1IOj299yQ9IwAAdBvCSDsSY06HkSp6RgAA6C6EkXZwmwYAgO5HGGlHYsttmipu0wAA0F0II+0I3qahZwQAgG5DGGlHYr/m2zQnGMAKAEC3IYy0o39scxg5zsPyAADoNoSRdqSeXvTsZFW9ahuaTK4GAIC+iTDSDndUhKIibJIkLw/MAwCgWxBG2mGxWJR6+um9x3w8nwYAgO5AGLmINHfz03uLy+kZAQCgOxBGLqJl3EgxPSMAAHQLwshFpMY194wcY8wIAADdgjByEWktPSPl9IwAANAdCCMX0dIzUkzPCAAA3YIwchEtPSPH6BkBAKBbEEYuoqVnxF/bqKq6RpOrAQCg7yGMXESMw65Yp10SM2oAAOgOhJEOaFlr5BhrjQAA0OUIIx3QsgorPSMAAHQ9wkgHpNIzAgBAtyGMdEAaq7ACANBtCCMd0DKjpvAUYQQAgK5GGOmAEZ5YSdLuYz4ZhmFyNQAA9C2EkQ4Y7omVM8Iqf22jtheUmV0OAAB9CmGkAyJsVt06yiNJWvNZqcnVAADQtxBGOmh8RpwkaV9JpbmFAADQxxBGOujKlOZxIwdKK0yuBACAvoUw0kHp8dGSmp/eyyBWAAC6DmGkg5JdDklSXWNA5dUNJlcDAEDfQRjpIGeETYn9IiVJx1j8DACALkMY6YRhKTGSpO0F5eYWAgBAH0IY6YSsIUmSpB2sNQIAQJchjHTCkP79JEmFp6pNrgQAgL6DMNIJmQnNM2qOnCSMAADQVQgjnTC4fz9ZLFJpRZ1KK2rNLgcAgD6BMNIJLmeERnhckqQ/5R0zuRoAAPoGwkgnfe3qAZKkD/aUmFwJAAB9A2GkkyZkxkmSispYawQAgK5AGOmktLgoSZLXX6umAMvCAwBwuQgjnZQc65TdalFTwFCJn0GsAABcLsJIJ9msFnncTknSsXJu1QAAcLkII5dgwOlbNUWEEQAALlunw8iGDRt0++23Ky0tTRaLRcuXL2/3+HXr1slisZz38nq9l1qz6dLjmxc/++J4lcmVAADQ+3U6jFRVVWncuHF6/vnnO3Xe3r17VVxcHHwlJyd39qt7jJYZNRsPnDC3EAAA+gB7Z0+YOXOmZs6c2ekvSk5OVlxcXKfP64luHN5fNqtFW4+U6Yvjlbqif4zZJQEA0GuFbMzI+PHjlZqaqltuuUV/+9vf2j22rq5Ofr+/1asnSY+P1viMOEnSn3eyEisAAJej28NIamqqXnzxRb399tt6++23lZGRoRtvvFHbt29v85ycnBy53e7gKyMjo7vL7LRhyc29Ib/4cL8CrDcCAMAlsxiGccm/pBaLRcuWLdPs2bM7dd4NN9ygzMxM/e53v7vg/rq6OtXV1QXf+/1+ZWRkyOfzyeVyXWq5XWrjgRP65m82SZLWfO8GbtUAAHAOv98vt9t90d9vU6b2Tp48WQcOHGhzv8PhkMvlavXqaaYOTdLEgfGSpF1HfSZXAwBA72VKGMnLy1NqaqoZX92lhntiJUkHSitNrgQAgN6r07NpKisrW/VqHDp0SHl5eUpISFBmZqYWLVqkoqIivfbaa5KkX/ziFxo8eLBGjRql2tpa/eY3v9HatWv1wQcfdN1VmGTo6Vsz+0srTK4EAIDeq9NhZOvWrbrpppuC7x955BFJ0ty5c7V48WIVFxeroKAguL++vl7f+973VFRUpOjoaI0dO1Yffvhhq8/orYalNIcRekYAALh0lzWANVQ6OgAm1Ly+Wk3JWSOb1aI9T8+Qw24zuyQAAHqMHj2Ata9IcTmUFONQU8DQm1sKzS4HAIBeiTByGSwWi6YNTZQkLd542NxiAADopQgjl+n704dLkg4er1Kpv9bkagAA6H0II5cpIyFao9Ka74N9cuiUydUAAND7EEa6wJQhzbdqcr/gKb4AAHQWYaQLXDcsSZK0Yd8J9YLJSQAA9CiEkS4wZXCiIu1WFZXXaMN+ekcAAOgMwkgXiIq06fph/SVJL2/4wuRqAADoXQgjXeRfbh4qSfrk4CmVMKsGAIAOI4x0kdED3BqaHKOmgKGX1h80uxwAAHoNwkgXsVktevTWEZKk13IPs+YIAAAdRBjpQtlXJWtQYrQaA4Z2HvWZXQ4AAL0CYaQLWSwWXZ0ZL0l679NjJlcDAEDvQBjpYrePS5MkrdzlVW1Dk8nVAADQ8xFGutiNw/urf6xD9U0BbTnM8vAAAFwMYaSLWSwW3TIyRZL0pzxu1QAAcDGEkW5w8/BkSdKuIgaxAgBwMYSRbjDy9FN895dW6lRVvcnVAADQsxFGukFaXJTGDHCrKWDo6n9brWJfjdklAQDQYxFGusn91w0K/vmZ9z4zrxAAAHo4wkg3uWPcAE0elCBJem9XsQIBw+SKAADomQgj3cRqteiNb09RVIRNhiH9au1+s0sCAKBHIox0I5vVIrvNIkn6xYf7VXCy2uSKAADoeQgj3exfbh4a/PM/vLhRhsHtGgAAzkYY6Wb3Txus6EibJKm0ok7bjpSZXBEAAD0LYaSb2W1WvfiPE4Pv/+l32+gdAQDgLISRELj+yv66b9ogSdLJqnr9as0BcwsCAKAHIYyEyJOzRgb//F8f7lN9Y8DEagAA6DkIIyFitVr01rys4Pt7X91sYjUAAPQchJEQumZQgjwupyRp4xcn9VruYXMLAgCgByCMhNjrD04O/vmpd3frpytZKh4AEN4IIyE2NDlWEwfGB9+/vOGg6hqbTKwIAABzEUZM8NK3JirWaQ++P8LKrACAMEYYMUFSjEO7fjJDowe4JEnT/2uDDp+oMrkqAADMQRgx0Zwpg4J//vpLuWriyb4AgDBEGDHR16/JCPaOlFbU6YrHV2r3MZ/JVQEAEFqEEZP9cd5UDU2OCb6f9/o21TYwoBUAED4IIyZzRti0fP40xTiaB7QWnqrRd17fZnJVAACEDmGkB4hx2JX/f2bosZkjJEkf7T2uQY+9p6NlzLIBAPR9hJEeZN4NVyj7quTg++v+/SOe8AsA6PMIIz3M/JuGtnr/bt4xkyoBACA0CCM9zITMeOV8dUzw/cI38/TUu/kKMO0XANBHEUZ6oLsnZ+rb1w8Jvn8t94g+2FNiYkUAAHQfwkgP9eitIzQna2Dw/Tvbj7IoGgCgTyKM9FA2q0U/uX2UXrl3kiwW6YM9JRr51Cq9s/2oNh08aXZ5AAB0GcJID2a1WnTziBTdO3WQJKmuMaBH/rBTd738id7NKzK3OAAAughhpBd4ctZIjRngbrXt4aV5mvc7FkcDAPR+hJFewGq16K15Wa3GkEjSqt1e5RfxLBsAQO9GGOklnBE2PX3HaH36k+mttv/9rz/WtiNlJlUFAMDlI4z0Mi5nhD57+tZW21buKjapGgAALl+nw8iGDRt0++23Ky0tTRaLRcuXL7/oOevWrdPVV18th8OhoUOHavHixZdQKlpERdr01rys4PvffnxI33l9mz76vFT/tXqfqusbTawOAIDO6XQYqaqq0rhx4/T888936PhDhw5p1qxZuummm5SXl6eFCxfqwQcf1Pvvv9/pYnHGNYMSdPCnt+nWUR5J0l/yvbpv8Rb9cs1+jXzqfWX/fL2KymtMrhIAgIuzGJfxJDaLxaJly5Zp9uzZbR7z6KOP6r333lN+fn5w2ze+8Q2Vl5dr1apVHfoev98vt9stn88nl8t1qeX2SYZh6P3dJZr3+oVn1nx/+pVacPOwEFcFAEDHf7+7fcxIbm6usrOzW22bMWOGcnNz2zynrq5Ofr+/1QsXZrFYdOtoj96al6V5N1xx3v7//GCfHl+2S4Wnqk2oDgCAi+v2MOL1epWSktJqW0pKivx+v2pqLnwbIScnR263O/jKyMjo7jJ7vWsGJeixmSP0wXev119/eJMevG5wcN+STQW688Vc5RWWy+urNbFKAADO1yNn0yxatEg+ny/4KiwsNLukXuPKlFhlJETr8duu0v/58qjgdq+/VrOf/5um5KzRqnyvLuPuHAAAXarbw4jH41FJSesnzpaUlMjlcikqKuqC5zgcDrlcrlYvdI7VatHcqYP0xU9v0/enX9lq37zXt2nwopX6885jJlUHAMAZ3R5GsrKytGbNmlbbVq9eraysrDbOQFeyWS164LohGpgYfd6+f3ljh37x4T6t21uqvd4KE6oDAECyd/aEyspKHThwIPj+0KFDysvLU0JCgjIzM7Vo0SIVFRXptddekyTNmzdP//3f/60f/vCHuv/++7V27Vr94Q9/0Hvvvdd1V4F2RUXa9Od/uU6+6gbNfXWzDh6vCu77xYf7g39etfDvNMLjUuGpakVH2pQY4zCjXABAmOn01N5169bppptuOm/73LlztXjxYt177706fPiw1q1b1+qc7373u9qzZ4/S09P15JNP6t577+3wdzK1t+vt9Vbo27/bqiMnW8+yGeGJ1eene0l2/ni63FERZpQHAOgDOvr7fVnrjIQKYaR71DU26Xe5R/TB7hJtPnzqvP1x0RH66w9vUrGvVu/mFWn+TUMVHdnpzjQAQJgijKBT/t/Gw3p6xR41Bdr+6zD/pis0c3SqhibHyBlhC2F1AIDeiDCCTqttaJJhSG9vP6ofLc9v99hl/zxVEzLjJTWvAtsUMGS39ciZ4gAAk3T095s+dwS19Hbcc22mPvf69fonBW0e+5X/2SiPy6nbxqTqZFWdNuw7rlULr1eKyxmqcgEAfQQ9I+iQsqp6/XLNfi3eeLjNY4b076el356i5FgCCQCA2zToBoGAoT9sLdTijYeV7HJqw77jFzzuiduu0uCkfnp90xEdLavRGw9NUf9YpgkDQLghjCAkBj3WsfVifjTrKn3z2kxm4wBAGCGMICR81Q266+Vcfe6t0Nh0t6YMSdTLGw62eXysw667r83U96cPV6TdqqaAIaul+enDAIC+hTCCkKpvDEiSIu1WVdQ26O1tR/WTP+9p95xIu1X1jQHdO3WQvnRVsq4ZlMCUYQDoQwgj6DF+uvKzdntLznb35EztKCjTj2aNlCvKrjED3PSaAEAvRRhBjxEIGDpRVafkWKd2Fpbrza2FSoiO1H9/dODiJ0t6+VsTNT4zjlk6ANDLEEbQ4/mqG/RfH+7TriKf9pdUyF/b2O7xt4xMUVJMpL4+KUPj0uNktdJjAgA9GWEEvYphGPreWzv1zvaiDh3vsFv1+G1Xqay6Xn/eeUxfvTpd35ycqfh+kd1cKQCgowgj6JW2F5SpX6Rdwz2x2njghIrKa/TyhoPaX1rZofMnDYzXM18ZoxOVdXpi2S7NGO3RD6YPl9VioScFAEKMMII+ZUdBmTISorUq33vR5+a05T++Nlb9Yx1avPGw7rk2U9NHebq4SgDA2Qgj6LPKq+u1cpdXAxOjtWRzgVJinTp4olLr9l54Rdi2PPvVMRo9wK0BcVHaXlCmd/OO6SdfHqV+DptK/XXKSIjupisAgPBAGEFY2nzolEoratUUMPSrNfslSV8cr+rw+V+dMED+2kZ9+FmJfjBjuL45OVM7Css0eXCiYhysHgsAnUEYAU47fKJKq3Z7tavIp7/uO37RWTsXkpkQrX+YmK5hyTFqDBganxFHzwkAXARhBGhDU8DQf36wV5sOntSoNLf+duCE/n5sqsqqG/S7T450+HMemzlCdQ0BNTQFdOtoj46V12jt56V6OHuYUt1RKvHXKtZp53k8AMIWYQToJMMw9OFnpZKkCJtFf91/QjEOu/647aiKymsu+XP/6foh8tc2qKa+ST++fVRw+nFDU0B2q4UVZgH0WYQRoIuU+mv19vYivbDugDITo/UPV6cr/5hff9x2tEs+/wczhispJlIffX5cT8y6Sikup0orapUez20gAL0bYQToYoGAIUOS7fR6JYZhaH9ppTbsO67/+95nkqQ7J6arpKJOG/Z1bmZPWyLtVj1yy5XaerhMh05UKntkinYUlGvL4VP6/YPXauoVSV3yPQDQHQgjQIg1BYxgUGl5X13fqNqGgHL+8pn+vPOYYp0RGuGJ1Zh0t2rqm/RabsfHqFxIqtup4xV1evqO0SqrrlfhqWotzL5SHjfP8QFgPsII0AtU1Dao2FerLYdP6Ylll7aY24XMGJWi1XtKFDCk7KuS9R//ME4Ou1Wv/u2Qbh3t0dDk2FbH7yup0MDEaDnsNtU2NMkZYeuyWgCEL8II0IsVnKxWdUOjUmKd2l5Qpsq6Rk0bmqRbfr5ekXarXv7WJP3jbzep4hKmKUvS3w1L0qdHfaptaFJdY0CS1C/SpnumDNRv/npQY9Lj9N3sYfLVNCjGYdeXrkqRJB0rr1FiTKQibVYG3gK4KMII0AdV1zeqMWDI5YxQdX2jDENyRtj09vajamwy5PXXBhd760pThiTo/mmD9e3fbQtum3pFor59/RCt33dcf8o7ptQ4p5796liNSnMRVABIIowAYauxKaCDJ6qUmRCt7QVl+ssur0oravX+7hJJ0lN/P1L7Syv0xubCbqthWHKMyk/3qqS6naqsa5Qzwqb9JRX6yoR0zZ6QpqHJMazBAvRxhBEA7aptaJLFIjnsNgUChkor6vThZyW6MiVWK3cVa/HGwyGr5Z+uH6JvTM5Uv0ibtheUacqQRMVFRwb3F56q1hfHK3Xj8OTgNsMw1Bgw1BQwGOMC9FCEEQCXpbahSXarRTarRQ+9tk0fflYim9Wiof1jtLekQt+8NlP3TxusVLdTh09W6fFl+dpZWN5l3x8XHaHy6gbFOu3BsTEup11j0+P0WbFfVadnKknS0OQYfWXCAN09OVMup13r9x3X6AFuxTrtioqwyV/TKJvNorKqepbxB0KIMAKgywQCzb0QkXZru8c1BQzVNTYpOtKuhqaADp2o0uZDp+RxOXVFcoyOV9Tp31d9rm1HypSZEK3hnlit3lMSoqtoNiw5RjPHpMoi6fork7Sz0Kei8hpNGhiv0oo6fXlcWnCV3BaNTQFV1TXJHR0R0lqB3o4wAqDXqKxr1Iqdx/TYO7skSbeN8WjlLq9GeGJ1ZUqsNn5xQicq60NWzw9mDNcnB08q94uTagyc+Sdy0sB4DUuJ0fRRHt14ZX9ZLBadqKxTfpFPU69IumhYA8INYQRAr3P4RJX6OezqH+s4b99L67/QSxsO6pV7r1F+kU9XpsTKX9OgiQPjVXCqWier6rTlcJkOHa/S1QPjVNcQ0C/X7Ff2VSnaV1KhgyeqQnYdYwa4VdfYpH0llRrhidVTfz9SH+wp0d8OnFBUpE0zRnlks1rUFDD0nRuukNVqkWEYrWYhlVXVa29JhdxREboqlX/30DsRRgDgNMMwtL2gXHHREfpd7hFNGhSv6Eib3thcqJr6JiW7HJqQEadIu1XvbC9SXmG5HHar/Je4jktnDUyM1pGT1ZKkqAibahqaLnjcq/ddo6sz47W7yKfjlXX647ajckbYNDAhWvdfN1hpcVEXPK+ovEZRETYlnHP7CehuhBEAuETn9lIYhqG9JRU6Vl4jX02DImxW7Trq00sbDgaPcUdFyGa16FRV6G4ntSU51qGy6nplJkTrxuHJ+u3Hh5TmduqfbrhCIzyxemNzgWobAnro+iEam+6W3WrRyl1eDUqK1qg0t46WVcsdFaFY5/ljZMqq6hUXHcFaMugQwggAhMDnXr8GJ/WTw948vbjYV6OdheVauqVQg5P6aV9JhXYW+nTt4AT9/Ovj9dcDx7VgyQ5JzQvHfW/6lTpeUa+TVXV6cnm+Auf8i3xVqkuDEqP1l3xvqC9NkTarrhkcr1tHebTp0Cmt+LRYkpQeH6XZ4wdoy+FT2nm0XFOvSNLNI5J1+7g0BQKGAoahYl+thntiVd8YUD8H68mEK8IIAPRgR05WKTnWqajI89dI2X3MJ4/LqcSYM2NnNh08qWdXfa4TlXX6h6sz5HE79OjbuzQ0OUYHSivP+4yWqdE9xRO3XaVfr90fvPU1JKmfvnltpry+Wm3Yf1z7SiqVmRCt+6cNUny/SI1Kc8kVFaHk2OaHPuYVlistzqlYR4TsNouee3+vpgxJ0PXD+suQFGFj8HBPRBgBgDBRWdeoytpGedxOBQLNjwVoGT/i9dUqoV+kPtpbKl9Ng/w1DRrSv58KT9Xox3/aLUn615uHqqKuUX87cEL7Ss4PNr3RCE+spo/yyGG3qqyqXrMnDNDIVJesVos2HzolSbpmUHyr202+6gbJ0nzLrSlgyGrRebfruD3VOYQRAEC76hqbZLNYZD+nV6HgZLWKymuUdUWiDMPQu3nHlOxyaOoVSTIMQ/6aRn3/jzsVYbPoJ7ePkt1m1ZtbCuWwW3XoRJX2FPu17UiZEvtF6uQFxtCkx0fJGWFTwalq1Z9+UKNZIu3W82roF2lTVX2TBif104KbhurIySqlx0fr1x/tl8fl1DWDEhQXHaHbxqTKbrXql2v26dOjPs0am6pSf52+dnW6PG6nkmKaBwxbLBbVNwbCcuo3YQQAYBpfdYNsNoucdqtqGwMqKqtRcqzjvAXlfNUNqmtqktdXq/T4aO0+5tMftx3V2PQ4bTxwQms+Lw0eO/WKRG384mSoL6XLZCRE6ebhyTpaVqMTlXXaedQnSbpzYroq6xrVP9ahytpG7Sut0JgBcSr118pus2jy4ETVNwaUHh+lVLdTQ5Nj1M9h115vhQYl9dP3/7BTCTGR+taUgfr0aLm+MiFdETaLKuoaZZEU64xQfpFPAxOjLzgoOa+wXIn9IrtldWLCCACg1ztRWacYhz34/KGqukbZrBbVNwX0hy2FGpnm0oSMeAUMQ58cPBkMOxnx0cHF9DxupypqG/X0ij2tPntUmks2q0Wfng4FZ3NGWDV2QJyKymtUVF7T/RcaAjarRRnxUaqub1LAaG7bs/3PPVfrtjGpXfqdhBEAAM5S29CkCJtVNuuFx334ahrU2BRoNXC4xb6SCg2IiwqGIqtFOlVVr73eCg2Ij1Lm6V6Fo2U1WrBke7DXQzpz2+dszghr8NlKPcX2J2/p8rVoCCMAAJissSkgu82q0opa+WsaNTQ5Rr6aBsU67Kqqb1SEzSpnhC04OHbbkTIt31GkkWku3TY6Vccr6/T/Nh6Wx+1UXmG5Vu8pUdaQRJX4azVpULyOV9QpoZ9DkXar3thc0G4tHpdTXn/tBff9y81D9b3pw7v8+gkjAAD0IfWNAdU2Nsl1gXEf56ptaJLFItXUN+lYea1GprX+7fx4/wntL63QqDS3Jg9O6K6SO/z7zUo0AAD0ApF2a4dn5LTcTnLYbYqLPv/Wy3XDknTdsKQure9yhN88IwAA0KMQRgAAgKkIIwAAwFSEEQAAYCrCCAAAMBVhBAAAmIowAgAATHVJYeT555/XoEGD5HQ6de2112rz5s1tHrt48WJZLJZWL6fTeckFAwCAvqXTYeTNN9/UI488oh//+Mfavn27xo0bpxkzZqi0tLTNc1wul4qLi4OvI0eOXFbRAACg7+h0GPn5z3+uhx56SPfdd59GjhypF198UdHR0XrllVfaPMdiscjj8QRfKSkpl1U0AADoOzoVRurr67Vt2zZlZ2ef+QCrVdnZ2crNzW3zvMrKSg0cOFAZGRm64447tHv37na/p66uTn6/v9ULAAD0TZ0KIydOnFBTU9N5PRspKSnyer0XPGf48OF65ZVX9O677+r1119XIBDQ1KlTdfTo0Ta/JycnR263O/jKyMjoTJkAAKAX6fbZNFlZWZozZ47Gjx+vG264Qe+884769++vl156qc1zFi1aJJ/PF3wVFhZ2d5kAAMAknXpqb1JSkmw2m0pKSlptLykpkcfj6dBnREREaMKECTpw4ECbxzgcDjkcjuB7wzAkids1AAD0Ii2/2y2/423pVBiJjIzUxIkTtWbNGs2ePVuSFAgEtGbNGi1YsKBDn9HU1KRdu3bptttu6/D3VlRUSBK3awAA6IUqKirkdrvb3N+pMCJJjzzyiObOnatJkyZp8uTJ+sUvfqGqqirdd999kqQ5c+ZowIABysnJkSQ9/fTTmjJlioYOHary8nI999xzOnLkiB588MEOf2daWpoKCwsVGxsri8XS2ZLb5Pf7lZGRocLCQrlcri773L6ItuoY2qnjaKuOoZ06jrbqmFC2k2EYqqioUFpaWrvHdTqM3HXXXTp+/Lieeuopeb1ejR8/XqtWrQoOai0oKJDVemYoSllZmR566CF5vV7Fx8dr4sSJ2rhxo0aOHNnh77RarUpPT+9sqR3mcrn4i9tBtFXH0E4dR1t1DO3UcbRVx4SqndrrEWlhMS52I6cP8/v9crvd8vl8/MW9CNqqY2injqOtOoZ26jjaqmN6YjvxbBoAAGCqsA4jDodDP/7xj1vN3MGF0VYdQzt1HG3VMbRTx9FWHdMT2ymsb9MAAADzhXXPCAAAMB9hBAAAmIowAgAATEUYAQAApgrrMPL8889r0KBBcjqduvbaa7V582azSwqpnJwcXXPNNYqNjVVycrJmz56tvXv3tjqmtrZW8+fPV2JiomJiYvS1r33tvGcTFRQUaNasWYqOjlZycrJ+8IMfqLGxMZSXElLPPvusLBaLFi5cGNxGOzUrKirSP/7jPyoxMVFRUVEaM2aMtm7dGtxvGIaeeuoppaamKioqStnZ2dq/f3+rzzh16pTuueceuVwuxcXF6YEHHlBlZWWoL6VbNTU16cknn9TgwYMVFRWlK664Qv/2b//W6vkd4dpWGzZs0O233660tDRZLBYtX7681f6uapdPP/1Uf/d3fyen06mMjAz9x3/8R3dfWpdqr50aGhr06KOPasyYMerXr5/S0tI0Z84cHTt2rNVn9Kh2MsLU0qVLjcjISOOVV14xdu/ebTz00ENGXFycUVJSYnZpITNjxgzj1VdfNfLz8428vDzjtttuMzIzM43KysrgMfPmzTMyMjKMNWvWGFu3bjWmTJliTJ06Nbi/sbHRGD16tJGdnW3s2LHDWLlypZGUlGQsWrTIjEvqdps3bzYGDRpkjB071nj44YeD22knwzh16pQxcOBA49577zU2bdpkHDx40Hj//feNAwcOBI959tlnDbfbbSxfvtzYuXOn8eUvf9kYPHiwUVNTEzzm1ltvNcaNG2d88sknxl//+ldj6NChxt13323GJXWbZ555xkhMTDRWrFhhHDp0yHjrrbeMmJgY45e//GXwmHBtq5UrVxpPPPGE8c477xiSjGXLlrXa3xXt4vP5jJSUFOOee+4x8vPzjTfeeMOIiooyXnrppVBd5mVrr53Ky8uN7Oxs48033zQ+//xzIzc315g8ebIxceLEVp/Rk9opbMPI5MmTjfnz5wffNzU1GWlpaUZOTo6JVZmrtLTUkGSsX7/eMIzmv9ARERHGW2+9FTzms88+MyQZubm5hmE0/w9htVoNr9cbPOaFF14wXC6XUVdXF9oL6GYVFRXGsGHDjNWrVxs33HBDMIzQTs0effRR47rrrmtzfyAQMDwej/Hcc88Ft5WXlxsOh8N44403DMMwjD179hiSjC1btgSP+ctf/mJYLBajqKio+4oPsVmzZhn3339/q21f/epXjXvuuccwDNqqxbk/sl3VLv/zP/9jxMfHt/p/79FHHzWGDx/ezVfUPS4U2s61efNmQ5Jx5MgRwzB6XjuF5W2a+vp6bdu2TdnZ2cFtVqtV2dnZys3NNbEyc/l8PklSQkKCJGnbtm1qaGho1U4jRoxQZmZmsJ1yc3M1ZsyY4LOJJGnGjBny+/3avXt3CKvvfvPnz9esWbNatYdEO7X405/+pEmTJunOO+9UcnKyJkyYoP/93/8N7j906JC8Xm+rdnK73br22mtbtVNcXJwmTZoUPCY7O1tWq1WbNm0K3cV0s6lTp2rNmjXat2+fJGnnzp36+OOPNXPmTEm0VVu6ql1yc3N1/fXXKzIyMnjMjBkztHfvXpWVlYXoakLL5/PJYrEoLi5OUs9rp04/KK8vOHHihJqamlr9MEhSSkqKPv/8c5OqMlcgENDChQs1bdo0jR49WpLk9XoVGRkZ/MvbIiUlRV6vN3jMhdqxZV9fsXTpUm3fvl1btmw5bx/t1OzgwYN64YUX9Mgjj+jxxx/Xli1b9K//+q+KjIzU3Llzg9d5oXY4u52Sk5Nb7bfb7UpISOgz7SRJjz32mPx+v0aMGCGbzaampiY988wzuueeeySJtmpDV7WL1+vV4MGDz/uMln3x8fHdUr9Zamtr9eijj+ruu+8OPoump7VTWIYRnG/+/PnKz8/Xxx9/bHYpPU5hYaEefvhhrV69Wk6n0+xyeqxAIKBJkybppz/9qSRpwoQJys/P14svvqi5c+eaXF3P8oc//EG///3vtWTJEo0aNUp5eXlauHCh0tLSaCt0qYaGBn3961+XYRh64YUXzC6nTWF5myYpKUk2m+282Q4lJSXyeDwmVWWeBQsWaMWKFfroo4+Unp4e3O7xeFRfX6/y8vJWx5/dTh6P54Lt2LKvL9i2bZtKS0t19dVXy263y263a/369frVr34lu92ulJQU2klSamqqRo4c2WrbVVddpYKCAklnrrO9/+88Ho9KS0tb7W9sbNSpU6f6TDtJ0g9+8AM99thj+sY3vqExY8boW9/6lr773e8qJydHEm3Vlq5ql3D4/1E6E0SOHDmi1atXt3pCb09rp7AMI5GRkZo4caLWrFkT3BYIBLRmzRplZWWZWFloGYahBQsWaNmyZVq7du153XETJ05UREREq3bau3evCgoKgu2UlZWlXbt2tfpL3fKX/twfpt7qS1/6knbt2qW8vLzga9KkSbrnnnuCf6adpGnTpp03NXzfvn0aOHCgJGnw4MHyeDyt2snv92vTpk2t2qm8vFzbtm0LHrN27VoFAgFde+21IbiK0KiurpbV2vqfX5vNpkAgIIm2aktXtUtWVpY2bNighoaG4DGrV6/W8OHD+8wtmpYgsn//fn344YdKTExstb/HtVOXD4ntJZYuXWo4HA5j8eLFxp49e4xvf/vbRlxcXKvZDn3dd77zHcPtdhvr1q0ziouLg6/q6urgMfPmzTMyMzONtWvXGlu3bjWysrKMrKys4P6WKavTp0838vLyjFWrVhn9+/fvU1NWL+Ts2TSGQTsZRvNofbvdbjzzzDPG/v37jd///vdGdHS08frrrwePefbZZ424uDjj3XffNT799FPjjjvuuOC0zAkTJhibNm0yPv74Y2PYsGG9frrquebOnWsMGDAgOLX3nXfeMZKSkowf/vCHwWPCta0qKiqMHTt2GDt27DAkGT//+c+NHTt2BGeBdEW7lJeXGykpKca3vvUtIz8/31i6dKkRHR3dq6b2ttdO9fX1xpe//GUjPT3dyMvLa/Xv+9kzY3pSO4VtGDEMw/j1r39tZGZmGpGRkcbkyZONTz75xOySQkrSBV+vvvpq8Jiamhrjn//5n434+HgjOjra+MpXvmIUFxe3+pzDhw8bM2fONKKiooykpCTje9/7ntHQ0BDiqwmtc8MI7dTsz3/+szF69GjD4XAYI0aMMF5++eVW+wOBgPHkk08aKSkphsPhML70pS8Ze/fubXXMyZMnjbvvvtuIiYkxXC6Xcd999xkVFRWhvIxu5/f7jYcfftjIzMw0nE6nMWTIEOOJJ55o9UMRrm310UcfXfDfpblz5xqG0XXtsnPnTuO6664zHA6HMWDAAOPZZ58N1SV2ifba6dChQ23++/7RRx8FP6MntZPFMM5a8g8AACDEwnLMCAAA6DkIIwAAwFSEEQAAYCrCCAAAMBVhBAAAmIowAgAATEUYAQAApiKMAAAAUxFGAACAqQgjAADAVIQRAABgKsIIAAAw1f8H4ANy06ZvHyQAAAAASUVORK5CYII=\n" }, "metadata": {} } ], "source": [ "plt.plot(torch.tensor(l).view(-1, 10).mean(1).numpy())" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "cPqZDUJ8I2AA", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "b602a445-1cd5-425e-dac1-6b2750ae0684" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "def _initialModel._to_java_impl():\n", " \"\"\"\n", " Deprecated in 2.3.0. Use :func:`pyspark.sql.types.DataType`, int or :class:`Column` expression in the given key (default param).\n", "\n", " >>> df = spark.range(1, 0).alias('age')).collect()\n", " [Row(name=u'Alice', age=1, name=u'Alice')]\n", " \"\n" ] } ], "source": [ "begin_text = torch.tensor(tok.encode('def '), device=device).unsqueeze(0)\n", "print(''.join(tok.decode(generate_batch(model, begin_text))))" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "V100", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 0 }