{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "48HjPXSxsiSO", "outputId": "1d36f06e-16a4-42bc-eba7-dc7d249589d7" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader\n", "from datasets import load_dataset\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "\n", "torch.manual_seed(12046)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Hn9ypPW0siSP" }, "outputs": [], "source": [ "# 一些超参数\n", "context_length = 10\n", "learning_rate = 0.01\n", "eval_iters = 10\n", "batch_size=1000\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "P_-nzG89siSQ", "outputId": "4329e089-bae9-417b-e67f-7cef29c9d1d4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def to_arrow_schema(schema):\n", " \"\"\" Convert a schema from Spark to Arrow\n", " \"\"\"\n", " import pyarrow as pa\n", " fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)\n", " for field in schema]\n", " return pa.schema(fields)\n", "['def to_arrow_schema(schema):\\n \"\"\" Convert a schema from Spark to Arrow\\n \"\"\"\\n import pyarrow as pa\\n fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)\\n for field in schema]\\n return pa.schema(fields)', 'def from_arrow_type(at):\\n \"\"\" Convert pyarrow type to Spark data type.\\n \"\"\"\\n import pyarrow.types as types\\n if types.is_boolean(at):\\n spark_type = BooleanType()\\n elif types.is_int8(at):\\n spark_type = ByteType()\\n elif types.is_int16(at):\\n spark_type = ShortType()\\n elif types.is_int32(at):\\n spark_type = IntegerType()\\n elif types.is_int64(at):\\n spark_type = LongType()\\n elif types.is_float32(at):\\n spark_type = FloatType()\\n elif types.is_float64(at):\\n spark_type = DoubleType()\\n elif types.is_decimal(at):\\n spark_type = DecimalType(precision=at.precision, scale=at.scale)\\n elif types.is_string(at):\\n spark_type = StringType()\\n elif types.is_binary(at):\\n spark_type = BinaryType()\\n elif types.is_date32(at):\\n spark_type = DateType()\\n elif types.is_timestamp(at):\\n spark_type = TimestampType()\\n elif types.is_list(at):\\n if types.is_timestamp(at.value_type):\\n raise TypeError(\"Unsupported type in conversion from Arrow: \" + str(at))\\n spark_type = ArrayType(from_arrow_type(at.value_type))\\n elif types.is_struct(at):\\n if any(types.is_struct(field.type) for field in at):\\n raise TypeError(\"Nested StructType not supported in conversion from Arrow: \" + str(at))\\n return StructType(\\n [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)\\n for field in at])\\n else:\\n raise TypeError(\"Unsupported type in conversion from Arrow: \" + str(at))\\n return spark_type']\n" ] } ], "source": [ "raw_datasets = load_dataset('code_search_net', 'python')\n", "datasets = raw_datasets['train'].filter(lambda x: 'apache/spark' in x['repository_name'])\n", "# 通过索引提取datasets数据的时候,返回一个dict,其中的value是一个字符串\n", "print(datasets[8]['whole_func_string'])\n", "# 当传入的是一个数组时,返回的依然是一个dict,但其中的value是一个列表\n", "print(datasets[8: 10]['whole_func_string'])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "Ky6gh_4TsiSQ" }, "outputs": [], "source": [ "class char_tokenizer:\n", "\n", " def __init__(self, data, begin_ind=0, end_ind=1):\n", " # 数据中出现的所有字符构成字典\n", " chars = sorted(list(set(''.join(data))))\n", " # 预留两个位置给开头和结尾的特殊字符\n", " self.char2ind = {s : i + 2 for i, s in enumerate(chars)}\n", " self.char2ind['<|b|>'] = begin_ind\n", " self.char2ind['<|e|>'] = end_ind\n", " self.begin_ind = begin_ind\n", " self.end_ind = end_ind\n", " self.ind2char = {i : s for s, i in self.char2ind.items()}\n", "\n", " def encode(self, text):\n", " '''\n", " 编码\n", " 参数\n", " ----\n", " text :str,文本\n", " '''\n", " return [self.char2ind[c] for c in text]\n", "\n", " def decode(self, enc):\n", " '''\n", " 解码\n", " 参数\n", " ----\n", " enc :int or list[int]\n", " '''\n", " if isinstance(enc, int):\n", " return self.ind2char[enc]\n", " return [self.ind2char[i] for i in enc]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yDS5elR7siSR", "outputId": "5c2e18d7-3f81-4873-a14d-c8a7b900a0d1" }, "outputs": [ { "data": { "text/plain": [ "('def postappend(self):', 99)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 举例验证分词器\n", "tok = char_tokenizer(datasets['whole_func_string'])\n", "example_text = 'def postappend(self):'\n", "''.join(tok.decode(tok.encode(example_text))), len(tok.char2ind)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "egLVxc1VsiSR", "outputId": "5d0b0760-2267-47d8-93f1-9bb326f25719" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|> ---> d\n", "<|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|>d ---> e\n", "<|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|>de ---> f\n", "<|b|><|b|><|b|><|b|><|b|><|b|><|b|>def ---> \n", "<|b|><|b|><|b|><|b|><|b|><|b|>def ---> p\n", "<|b|><|b|><|b|><|b|><|b|>def p ---> o\n", "<|b|><|b|><|b|><|b|>def po ---> s\n", "<|b|><|b|><|b|>def pos ---> t\n", "<|b|><|b|>def post ---> a\n", "<|b|>def posta ---> p\n", "def postap ---> p\n", "ef postapp ---> e\n", "f postappe ---> n\n", " postappen ---> d\n", "postappend ---> (\n", "ostappend( ---> s\n", "stappend(s ---> e\n", "tappend(se ---> l\n", "append(sel ---> f\n", "ppend(self ---> )\n", "pend(self) ---> :\n", "end(self): ---> <|e|>\n" ] } ], "source": [ "def autoregressive_trans(text, tokenizer, context_length=context_length):\n", " '''\n", " 将文本转换成一系列的训练数据\n", " 参数\n", " ----\n", " text :str,文本\n", " tokenizer :分词器\n", " context_length :int,背景文本的长度\n", " 返回\n", " ----\n", " inputs :list[list[int]],背景文本(特征)\n", " labels :list[list[int]],预测标签\n", " '''\n", " inputs, labels = [], []\n", " b_ind = tokenizer.begin_ind\n", " e_ind = tokenizer.end_ind\n", " enc = tokenizer.encode(text)\n", " # 增加开始和结尾的特殊字符\n", " x = [b_ind] * context_length + enc + [e_ind]\n", " for i in range(len(x) - context_length):\n", " inputs.append(x[i: i + context_length])\n", " labels.append(x[i + context_length])\n", " return inputs, labels\n", "\n", "# 举例展示自回归模式的训练数据\n", "inputs, labels = autoregressive_trans(example_text, tok)\n", "for a, b in zip(inputs, labels):\n", " print(''.join(tok.decode(a)), '--->', tok.decode(b))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7kOWxKCisiSR", "outputId": "82ab4276-dc85-42e2-c6a9-e2e99629928f" }, "outputs": [ { "data": { "text/plain": [ "{'inputs': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 71],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 71, 72],\n", " [0, 0, 0, 0, 0, 0, 0, 71, 72, 73],\n", " [0, 0, 0, 0, 0, 0, 71, 72, 73, 3],\n", " [0, 0, 0, 0, 0, 71, 72, 73, 3, 87],\n", " [0, 0, 0, 0, 71, 72, 73, 3, 87, 82],\n", " [0, 0, 0, 71, 72, 73, 3, 87, 82, 66],\n", " [0, 0, 71, 72, 73, 3, 87, 82, 66, 68],\n", " [0, 71, 72, 73, 3, 87, 82, 66, 68, 85],\n", " [71, 72, 73, 3, 87, 82, 66, 68, 85, 85],\n", " [72, 73, 3, 87, 82, 66, 68, 85, 85, 82],\n", " [73, 3, 87, 82, 66, 68, 85, 85, 82, 90],\n", " [3, 87, 82, 66, 68, 85, 85, 82, 90, 66],\n", " [87, 82, 66, 68, 85, 85, 82, 90, 66, 86],\n", " [82, 66, 68, 85, 85, 82, 90, 66, 86, 70],\n", " [66, 68, 85, 85, 82, 90, 66, 86, 70, 75],\n", " [68, 85, 85, 82, 90, 66, 86, 70, 75, 72],\n", " [85, 85, 82, 90, 66, 86, 70, 75, 72, 80],\n", " [85, 82, 90, 66, 86, 70, 75, 72, 80, 68],\n", " [82, 90, 66, 86, 70, 75, 72, 80, 68, 11],\n", " [90, 66, 86, 70, 75, 72, 80, 68, 11, 86],\n", " [66, 86, 70, 75, 72, 80, 68, 11, 86, 70],\n", " [86, 70, 75, 72, 80, 68, 11, 86, 70, 75],\n", " [70, 75, 72, 80, 68, 11, 86, 70, 75, 72],\n", " [75, 72, 80, 68, 11, 86, 70, 75, 72, 80],\n", " [72, 80, 68, 11, 86, 70, 75, 72, 80, 68],\n", " [80, 68, 11, 86, 70, 75, 72, 80, 68, 12],\n", " [68, 11, 86, 70, 75, 72, 80, 68, 12, 29],\n", " [11, 86, 70, 75, 72, 80, 68, 12, 29, 2],\n", " [86, 70, 75, 72, 80, 68, 12, 29, 2, 3],\n", " [70, 75, 72, 80, 68, 12, 29, 2, 3, 3],\n", " [75, 72, 80, 68, 12, 29, 2, 3, 3, 3],\n", " [72, 80, 68, 12, 29, 2, 3, 3, 3, 3],\n", " [80, 68, 12, 29, 2, 3, 3, 3, 3, 5],\n", " [68, 12, 29, 2, 3, 3, 3, 3, 5, 5],\n", " [12, 29, 2, 3, 3, 3, 3, 5, 5, 5],\n", " [29, 2, 3, 3, 3, 3, 5, 5, 5, 3],\n", " [2, 3, 3, 3, 3, 5, 5, 5, 3, 38],\n", " [3, 3, 3, 3, 5, 5, 5, 3, 38, 82],\n", " [3, 3, 3, 5, 5, 5, 3, 38, 82, 81],\n", " [3, 3, 5, 5, 5, 3, 38, 82, 81, 89],\n", " [3, 5, 5, 5, 3, 38, 82, 81, 89, 72],\n", " [5, 5, 5, 3, 38, 82, 81, 89, 72, 85],\n", " [5, 5, 3, 38, 82, 81, 89, 72, 85, 87],\n", " [5, 3, 38, 82, 81, 89, 72, 85, 87, 3],\n", " [3, 38, 82, 81, 89, 72, 85, 87, 3, 68],\n", " [38, 82, 81, 89, 72, 85, 87, 3, 68, 3],\n", " [82, 81, 89, 72, 85, 87, 3, 68, 3, 86],\n", " [81, 89, 72, 85, 87, 3, 68, 3, 86, 70],\n", " [89, 72, 85, 87, 3, 68, 3, 86, 70, 75],\n", " [72, 85, 87, 3, 68, 3, 86, 70, 75, 72],\n", " [85, 87, 3, 68, 3, 86, 70, 75, 72, 80],\n", " [87, 3, 68, 3, 86, 70, 75, 72, 80, 68],\n", " [3, 68, 3, 86, 70, 75, 72, 80, 68, 3],\n", " [68, 3, 86, 70, 75, 72, 80, 68, 3, 73],\n", " [3, 86, 70, 75, 72, 80, 68, 3, 73, 85],\n", " [86, 70, 75, 72, 80, 68, 3, 73, 85, 82],\n", " [70, 75, 72, 80, 68, 3, 73, 85, 82, 80],\n", " [75, 72, 80, 68, 3, 73, 85, 82, 80, 3],\n", " [72, 80, 68, 3, 73, 85, 82, 80, 3, 54],\n", " [80, 68, 3, 73, 85, 82, 80, 3, 54, 83],\n", " [68, 3, 73, 85, 82, 80, 3, 54, 83, 68],\n", " [3, 73, 85, 82, 80, 3, 54, 83, 68, 85],\n", " [73, 85, 82, 80, 3, 54, 83, 68, 85, 78],\n", " [85, 82, 80, 3, 54, 83, 68, 85, 78, 3],\n", " [82, 80, 3, 54, 83, 68, 85, 78, 3, 87],\n", " [80, 3, 54, 83, 68, 85, 78, 3, 87, 82],\n", " [3, 54, 83, 68, 85, 78, 3, 87, 82, 3],\n", " [54, 83, 68, 85, 78, 3, 87, 82, 3, 36],\n", " [83, 68, 85, 78, 3, 87, 82, 3, 36, 85],\n", " [68, 85, 78, 3, 87, 82, 3, 36, 85, 85],\n", " [85, 78, 3, 87, 82, 3, 36, 85, 85, 82],\n", " [78, 3, 87, 82, 3, 36, 85, 85, 82, 90],\n", " [3, 87, 82, 3, 36, 85, 85, 82, 90, 2],\n", " [87, 82, 3, 36, 85, 85, 82, 90, 2, 3],\n", " [82, 3, 36, 85, 85, 82, 90, 2, 3, 3],\n", " [3, 36, 85, 85, 82, 90, 2, 3, 3, 3],\n", " [36, 85, 85, 82, 90, 2, 3, 3, 3, 3],\n", " [85, 85, 82, 90, 2, 3, 3, 3, 3, 5],\n", " [85, 82, 90, 2, 3, 3, 3, 3, 5, 5],\n", " [82, 90, 2, 3, 3, 3, 3, 5, 5, 5],\n", " [90, 2, 3, 3, 3, 3, 5, 5, 5, 2],\n", " [2, 3, 3, 3, 3, 5, 5, 5, 2, 3],\n", " [3, 3, 3, 3, 5, 5, 5, 2, 3, 3],\n", " [3, 3, 3, 5, 5, 5, 2, 3, 3, 3],\n", " [3, 3, 5, 5, 5, 2, 3, 3, 3, 3],\n", " [3, 5, 5, 5, 2, 3, 3, 3, 3, 76],\n", " [5, 5, 5, 2, 3, 3, 3, 3, 76, 80],\n", " [5, 5, 2, 3, 3, 3, 3, 76, 80, 83],\n", " [5, 2, 3, 3, 3, 3, 76, 80, 83, 82],\n", " [2, 3, 3, 3, 3, 76, 80, 83, 82, 85],\n", " [3, 3, 3, 3, 76, 80, 83, 82, 85, 87],\n", " [3, 3, 3, 76, 80, 83, 82, 85, 87, 3],\n", " [3, 3, 76, 80, 83, 82, 85, 87, 3, 83],\n", " [3, 76, 80, 83, 82, 85, 87, 3, 83, 92],\n", " [76, 80, 83, 82, 85, 87, 3, 83, 92, 68],\n", " [80, 83, 82, 85, 87, 3, 83, 92, 68, 85],\n", " [83, 82, 85, 87, 3, 83, 92, 68, 85, 85],\n", " [82, 85, 87, 3, 83, 92, 68, 85, 85, 82],\n", " [85, 87, 3, 83, 92, 68, 85, 85, 82, 90],\n", " [87, 3, 83, 92, 68, 85, 85, 82, 90, 3],\n", " [3, 83, 92, 68, 85, 85, 82, 90, 3, 68],\n", " [83, 92, 68, 85, 85, 82, 90, 3, 68, 86],\n", " [92, 68, 85, 85, 82, 90, 3, 68, 86, 3],\n", " [68, 85, 85, 82, 90, 3, 68, 86, 3, 83],\n", " [85, 85, 82, 90, 3, 68, 86, 3, 83, 68],\n", " [85, 82, 90, 3, 68, 86, 3, 83, 68, 2],\n", " [82, 90, 3, 68, 86, 3, 83, 68, 2, 3],\n", " [90, 3, 68, 86, 3, 83, 68, 2, 3, 3],\n", " [3, 68, 86, 3, 83, 68, 2, 3, 3, 3],\n", " [68, 86, 3, 83, 68, 2, 3, 3, 3, 3],\n", " [86, 3, 83, 68, 2, 3, 3, 3, 3, 73],\n", " [3, 83, 68, 2, 3, 3, 3, 3, 73, 76],\n", " [83, 68, 2, 3, 3, 3, 3, 73, 76, 72],\n", " [68, 2, 3, 3, 3, 3, 73, 76, 72, 79],\n", " [2, 3, 3, 3, 3, 73, 76, 72, 79, 71],\n", " [3, 3, 3, 3, 73, 76, 72, 79, 71, 86],\n", " [3, 3, 3, 73, 76, 72, 79, 71, 86, 3],\n", " [3, 3, 73, 76, 72, 79, 71, 86, 3, 32],\n", " [3, 73, 76, 72, 79, 71, 86, 3, 32, 3],\n", " [73, 76, 72, 79, 71, 86, 3, 32, 3, 62],\n", " [76, 72, 79, 71, 86, 3, 32, 3, 62, 83],\n", " [72, 79, 71, 86, 3, 32, 3, 62, 83, 68],\n", " [79, 71, 86, 3, 32, 3, 62, 83, 68, 17],\n", " [71, 86, 3, 32, 3, 62, 83, 68, 17, 73],\n", " [86, 3, 32, 3, 62, 83, 68, 17, 73, 76],\n", " [3, 32, 3, 62, 83, 68, 17, 73, 76, 72],\n", " [32, 3, 62, 83, 68, 17, 73, 76, 72, 79],\n", " [3, 62, 83, 68, 17, 73, 76, 72, 79, 71],\n", " [62, 83, 68, 17, 73, 76, 72, 79, 71, 11],\n", " [83, 68, 17, 73, 76, 72, 79, 71, 11, 73],\n", " [68, 17, 73, 76, 72, 79, 71, 11, 73, 76],\n", " [17, 73, 76, 72, 79, 71, 11, 73, 76, 72],\n", " [73, 76, 72, 79, 71, 11, 73, 76, 72, 79],\n", " [76, 72, 79, 71, 11, 73, 76, 72, 79, 71],\n", " [72, 79, 71, 11, 73, 76, 72, 79, 71, 17],\n", " [79, 71, 11, 73, 76, 72, 79, 71, 17, 81],\n", " [71, 11, 73, 76, 72, 79, 71, 17, 81, 68],\n", " [11, 73, 76, 72, 79, 71, 17, 81, 68, 80],\n", " [73, 76, 72, 79, 71, 17, 81, 68, 80, 72],\n", " [76, 72, 79, 71, 17, 81, 68, 80, 72, 15],\n", " [72, 79, 71, 17, 81, 68, 80, 72, 15, 3],\n", " [79, 71, 17, 81, 68, 80, 72, 15, 3, 87],\n", " [71, 17, 81, 68, 80, 72, 15, 3, 87, 82],\n", " [17, 81, 68, 80, 72, 15, 3, 87, 82, 66],\n", " [81, 68, 80, 72, 15, 3, 87, 82, 66, 68],\n", " [68, 80, 72, 15, 3, 87, 82, 66, 68, 85],\n", " [80, 72, 15, 3, 87, 82, 66, 68, 85, 85],\n", " [72, 15, 3, 87, 82, 66, 68, 85, 85, 82],\n", " [15, 3, 87, 82, 66, 68, 85, 85, 82, 90],\n", " [3, 87, 82, 66, 68, 85, 85, 82, 90, 66],\n", " [87, 82, 66, 68, 85, 85, 82, 90, 66, 87],\n", " [82, 66, 68, 85, 85, 82, 90, 66, 87, 92],\n", " [66, 68, 85, 85, 82, 90, 66, 87, 92, 83],\n", " [68, 85, 85, 82, 90, 66, 87, 92, 83, 72],\n", " [85, 85, 82, 90, 66, 87, 92, 83, 72, 11],\n", " [85, 82, 90, 66, 87, 92, 83, 72, 11, 73],\n", " [82, 90, 66, 87, 92, 83, 72, 11, 73, 76],\n", " [90, 66, 87, 92, 83, 72, 11, 73, 76, 72],\n", " [66, 87, 92, 83, 72, 11, 73, 76, 72, 79],\n", " [87, 92, 83, 72, 11, 73, 76, 72, 79, 71],\n", " [92, 83, 72, 11, 73, 76, 72, 79, 71, 17],\n", " [83, 72, 11, 73, 76, 72, 79, 71, 17, 71],\n", " [72, 11, 73, 76, 72, 79, 71, 17, 71, 68],\n", " [11, 73, 76, 72, 79, 71, 17, 71, 68, 87],\n", " [73, 76, 72, 79, 71, 17, 71, 68, 87, 68],\n", " [76, 72, 79, 71, 17, 71, 68, 87, 68, 55],\n", " [72, 79, 71, 17, 71, 68, 87, 68, 55, 92],\n", " [79, 71, 17, 71, 68, 87, 68, 55, 92, 83],\n", " [71, 17, 71, 68, 87, 68, 55, 92, 83, 72],\n", " [17, 71, 68, 87, 68, 55, 92, 83, 72, 12],\n", " [71, 68, 87, 68, 55, 92, 83, 72, 12, 15],\n", " [68, 87, 68, 55, 92, 83, 72, 12, 15, 3],\n", " [87, 68, 55, 92, 83, 72, 12, 15, 3, 81],\n", " [68, 55, 92, 83, 72, 12, 15, 3, 81, 88],\n", " [55, 92, 83, 72, 12, 15, 3, 81, 88, 79],\n", " [92, 83, 72, 12, 15, 3, 81, 88, 79, 79],\n", " [83, 72, 12, 15, 3, 81, 88, 79, 79, 68],\n", " [72, 12, 15, 3, 81, 88, 79, 79, 68, 69],\n", " [12, 15, 3, 81, 88, 79, 79, 68, 69, 79],\n", " [15, 3, 81, 88, 79, 79, 68, 69, 79, 72],\n", " [3, 81, 88, 79, 79, 68, 69, 79, 72, 32],\n", " [81, 88, 79, 79, 68, 69, 79, 72, 32, 73],\n", " [88, 79, 79, 68, 69, 79, 72, 32, 73, 76],\n", " [79, 79, 68, 69, 79, 72, 32, 73, 76, 72],\n", " [79, 68, 69, 79, 72, 32, 73, 76, 72, 79],\n", " [68, 69, 79, 72, 32, 73, 76, 72, 79, 71],\n", " [69, 79, 72, 32, 73, 76, 72, 79, 71, 17],\n", " [79, 72, 32, 73, 76, 72, 79, 71, 17, 81],\n", " [72, 32, 73, 76, 72, 79, 71, 17, 81, 88],\n", " [32, 73, 76, 72, 79, 71, 17, 81, 88, 79],\n", " [73, 76, 72, 79, 71, 17, 81, 88, 79, 79],\n", " [76, 72, 79, 71, 17, 81, 88, 79, 79, 68],\n", " [72, 79, 71, 17, 81, 88, 79, 79, 68, 69],\n", " [79, 71, 17, 81, 88, 79, 79, 68, 69, 79],\n", " [71, 17, 81, 88, 79, 79, 68, 69, 79, 72],\n", " [17, 81, 88, 79, 79, 68, 69, 79, 72, 12],\n", " [81, 88, 79, 79, 68, 69, 79, 72, 12, 2],\n", " [88, 79, 79, 68, 69, 79, 72, 12, 2, 3],\n", " [79, 79, 68, 69, 79, 72, 12, 2, 3, 3],\n", " [79, 68, 69, 79, 72, 12, 2, 3, 3, 3],\n", " [68, 69, 79, 72, 12, 2, 3, 3, 3, 3],\n", " [69, 79, 72, 12, 2, 3, 3, 3, 3, 3],\n", " [79, 72, 12, 2, 3, 3, 3, 3, 3, 3],\n", " [72, 12, 2, 3, 3, 3, 3, 3, 3, 3],\n", " [12, 2, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [2, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 73],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 73, 82],\n", " [3, 3, 3, 3, 3, 3, 3, 73, 82, 85],\n", " [3, 3, 3, 3, 3, 3, 73, 82, 85, 3],\n", " [3, 3, 3, 3, 3, 73, 82, 85, 3, 73],\n", " [3, 3, 3, 3, 73, 82, 85, 3, 73, 76],\n", " [3, 3, 3, 73, 82, 85, 3, 73, 76, 72],\n", " [3, 3, 73, 82, 85, 3, 73, 76, 72, 79],\n", " [3, 73, 82, 85, 3, 73, 76, 72, 79, 71],\n", " [73, 82, 85, 3, 73, 76, 72, 79, 71, 3],\n", " [82, 85, 3, 73, 76, 72, 79, 71, 3, 76],\n", " [85, 3, 73, 76, 72, 79, 71, 3, 76, 81],\n", " [3, 73, 76, 72, 79, 71, 3, 76, 81, 3],\n", " [73, 76, 72, 79, 71, 3, 76, 81, 3, 86],\n", " [76, 72, 79, 71, 3, 76, 81, 3, 86, 70],\n", " [72, 79, 71, 3, 76, 81, 3, 86, 70, 75],\n", " [79, 71, 3, 76, 81, 3, 86, 70, 75, 72],\n", " [71, 3, 76, 81, 3, 86, 70, 75, 72, 80],\n", " [3, 76, 81, 3, 86, 70, 75, 72, 80, 68],\n", " [76, 81, 3, 86, 70, 75, 72, 80, 68, 64],\n", " [81, 3, 86, 70, 75, 72, 80, 68, 64, 2],\n", " [3, 86, 70, 75, 72, 80, 68, 64, 2, 3],\n", " [86, 70, 75, 72, 80, 68, 64, 2, 3, 3],\n", " [70, 75, 72, 80, 68, 64, 2, 3, 3, 3],\n", " [75, 72, 80, 68, 64, 2, 3, 3, 3, 3],\n", " [72, 80, 68, 64, 2, 3, 3, 3, 3, 85],\n", " [80, 68, 64, 2, 3, 3, 3, 3, 85, 72],\n", " [68, 64, 2, 3, 3, 3, 3, 85, 72, 87],\n", " [64, 2, 3, 3, 3, 3, 85, 72, 87, 88],\n", " [2, 3, 3, 3, 3, 85, 72, 87, 88, 85],\n", " [3, 3, 3, 3, 85, 72, 87, 88, 85, 81],\n", " [3, 3, 3, 85, 72, 87, 88, 85, 81, 3],\n", " [3, 3, 85, 72, 87, 88, 85, 81, 3, 83],\n", " [3, 85, 72, 87, 88, 85, 81, 3, 83, 68],\n", " [85, 72, 87, 88, 85, 81, 3, 83, 68, 17],\n", " [72, 87, 88, 85, 81, 3, 83, 68, 17, 86],\n", " [87, 88, 85, 81, 3, 83, 68, 17, 86, 70],\n", " [88, 85, 81, 3, 83, 68, 17, 86, 70, 75],\n", " [85, 81, 3, 83, 68, 17, 86, 70, 75, 72],\n", " [81, 3, 83, 68, 17, 86, 70, 75, 72, 80],\n", " [3, 83, 68, 17, 86, 70, 75, 72, 80, 68],\n", " [83, 68, 17, 86, 70, 75, 72, 80, 68, 11],\n", " [68, 17, 86, 70, 75, 72, 80, 68, 11, 73],\n", " [17, 86, 70, 75, 72, 80, 68, 11, 73, 76],\n", " [86, 70, 75, 72, 80, 68, 11, 73, 76, 72],\n", " [70, 75, 72, 80, 68, 11, 73, 76, 72, 79],\n", " [75, 72, 80, 68, 11, 73, 76, 72, 79, 71],\n", " [72, 80, 68, 11, 73, 76, 72, 79, 71, 86],\n", " [80, 68, 11, 73, 76, 72, 79, 71, 86, 12]],\n", " 'labels': [71,\n", " 72,\n", " 73,\n", " 3,\n", " 87,\n", " 82,\n", " 66,\n", " 68,\n", " 85,\n", " 85,\n", " 82,\n", " 90,\n", " 66,\n", " 86,\n", " 70,\n", " 75,\n", " 72,\n", " 80,\n", " 68,\n", " 11,\n", " 86,\n", " 70,\n", " 75,\n", " 72,\n", " 80,\n", " 68,\n", " 12,\n", " 29,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 5,\n", " 5,\n", " 5,\n", " 3,\n", " 38,\n", " 82,\n", " 81,\n", " 89,\n", " 72,\n", " 85,\n", " 87,\n", " 3,\n", " 68,\n", " 3,\n", " 86,\n", " 70,\n", " 75,\n", " 72,\n", " 80,\n", " 68,\n", " 3,\n", " 73,\n", " 85,\n", " 82,\n", " 80,\n", " 3,\n", " 54,\n", " 83,\n", " 68,\n", " 85,\n", " 78,\n", " 3,\n", " 87,\n", " 82,\n", " 3,\n", " 36,\n", " 85,\n", " 85,\n", " 82,\n", " 90,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 5,\n", " 5,\n", " 5,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 76,\n", " 80,\n", " 83,\n", " 82,\n", " 85,\n", " 87,\n", " 3,\n", " 83,\n", " 92,\n", " 68,\n", " 85,\n", " 85,\n", " 82,\n", " 90,\n", " 3,\n", " 68,\n", " 86,\n", " 3,\n", " 83,\n", " 68,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 86,\n", " 3,\n", " 32,\n", " 3,\n", " 62,\n", " 83,\n", " 68,\n", " 17,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 11,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 17,\n", " 81,\n", " 68,\n", " 80,\n", " 72,\n", " 15,\n", " 3,\n", " 87,\n", " 82,\n", " 66,\n", " 68,\n", " 85,\n", " 85,\n", " 82,\n", " 90,\n", " 66,\n", " 87,\n", " 92,\n", " 83,\n", " 72,\n", " 11,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 17,\n", " 71,\n", " 68,\n", " 87,\n", " 68,\n", " 55,\n", " 92,\n", " 83,\n", " 72,\n", " 12,\n", " 15,\n", " 3,\n", " 81,\n", " 88,\n", " 79,\n", " 79,\n", " 68,\n", " 69,\n", " 79,\n", " 72,\n", " 32,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 17,\n", " 81,\n", " 88,\n", " 79,\n", " 79,\n", " 68,\n", " 69,\n", " 79,\n", " 72,\n", " 12,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 73,\n", " 82,\n", " 85,\n", " 3,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 3,\n", " 76,\n", " 81,\n", " 3,\n", " 86,\n", " 70,\n", " 75,\n", " 72,\n", " 80,\n", " 68,\n", " 64,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 85,\n", " 72,\n", " 87,\n", " 88,\n", " 85,\n", " 81,\n", " 3,\n", " 83,\n", " 68,\n", " 17,\n", " 86,\n", " 70,\n", " 75,\n", " 72,\n", " 80,\n", " 68,\n", " 11,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 86,\n", " 12,\n", " 1]}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def process(data):\n", " '''\n", " 在datasets的map里使用,将文本转换成训练数据\n", " '''\n", " text = data['whole_func_string']\n", " # 如果是普通的map操作,传入的值是字符串\n", " if isinstance(text, str):\n", " inputs, labels = autoregressive_trans(text, tok)\n", " return {'inputs': inputs, 'labels': labels}\n", " # 如果是map操作里面batched=True,传入的值是字符串列表\n", " inputs, labels = [], []\n", " for i in text:\n", " i, l = autoregressive_trans(i, tok)\n", " inputs += i\n", " labels += l\n", " return {'inputs': inputs, 'labels': labels}\n", "\n", "process(datasets[8:9])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "U0Ojs8kbsiSS", "outputId": "0e571fc8-2432-4655-d276-538a79031000" }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([645401, 10]), torch.Size([645401]))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 将数据分为训练集和测试集\n", "tokenized = datasets.train_test_split(test_size=0.1, seed=1024, shuffle=True)\n", "# 将文本转换为训练数据,里面包含inputs和labels\n", "tokenized = tokenized.map(process, batched=True, remove_columns=datasets.column_names)\n", "tokenized.set_format(type='torch', device=device)\n", "\n", "tokenized['train']['inputs'].shape, tokenized['train']['labels'].shape" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SdfNybxcsiSS", "outputId": "d8276b00-2799-4a1c-d0b5-41c9e515803f" }, "outputs": [ { "data": { "text/plain": [ "{'inputs': tensor([[38, 85, 72, ..., 12, 2, 3],\n", " [ 3, 76, 73, ..., 86, 87, 66],\n", " [80, 72, 15, ..., 75, 15, 3],\n", " ...,\n", " [75, 76, 86, ..., 79, 3, 68],\n", " [87, 68, 81, ..., 82, 90, 15],\n", " [ 3, 3, 3, ..., 68, 85, 68]], device='cuda:0'),\n", " 'labels': tensor([ 3, 86, 86, 90, 72, 87, 3, 81, 76, 29, 74, 87, 87, 87, 71, 3, 3, 3,\n", " 15, 3, 68, 3, 87, 75, 86, 17, 75, 13, 3, 80, 3, 83, 86, 68, 3, 83,\n", " 3, 81, 53, 83, 11, 66, 3, 85, 85, 82, 76, 85, 73, 14, 3, 75, 3, 3,\n", " 88, 3, 3, 3, 79, 82, 3, 50, 58, 11, 17, 3, 86, 93, 3, 87, 3, 74,\n", " 78, 78, 80, 54, 3, 2, 81, 80, 3, 81, 3, 3, 74, 3, 76, 82, 82, 3,\n", " 60, 12, 5, 73, 75, 72, 87, 3, 66, 73, 69, 85, 3, 76, 3, 3, 3, 86,\n", " 87, 72, 2, 3, 70, 85, 3, 3, 72, 32, 3, 3, 68, 3, 79, 72, 88, 3,\n", " 70, 76, 11, 3, 80, 72, 3, 87, 71, 3, 76, 73, 11, 82, 8, 84, 3, 72,\n", " 3, 76, 3, 49, 79, 87, 3, 3, 73, 12, 79, 90, 82, 83, 3, 71, 86, 91,\n", " 3, 75, 87, 83, 75, 72, 87, 76, 3, 3, 71, 92, 17, 87, 10, 87, 6, 70,\n", " 10, 3, 88, 3, 81, 86, 12, 72, 3, 3, 81, 72, 19, 68, 76, 44, 73, 82,\n", " 70, 82, 68, 3, 88, 86, 3, 5, 67, 72, 79, 70, 81, 3, 3, 75, 87, 3,\n", " 3, 76, 6, 3, 72, 70, 87, 3, 3, 87, 71, 3, 68, 2, 83, 85, 71, 87,\n", " 72, 33, 15, 87, 74, 3, 3, 29, 68, 3, 75, 85, 80, 3, 70, 72, 70, 82,\n", " 73, 82, 70, 86, 15, 3, 68, 5, 86, 82, 3, 3, 3, 3, 87, 3, 17, 76,\n", " 75, 87, 93, 3, 2, 81, 90, 3, 91, 72, 19, 3, 3, 69, 49, 3, 71, 81,\n", " 70, 5, 33, 3, 87, 3, 76, 20, 29, 3, 71, 68, 88, 86, 86, 17, 77, 82,\n", " 88, 3, 73, 3, 76, 3, 3, 81, 79, 72, 73, 82, 3, 82, 83, 72, 54, 20,\n", " 72, 72, 81, 2, 3, 80, 87, 85, 83, 80, 87, 72, 72, 76, 82, 3, 27, 81,\n", " 3, 3, 3, 86, 80, 66, 17, 76, 71, 80, 87, 17, 79, 83, 76, 83, 3, 3,\n", " 72, 76, 3, 87, 87, 46, 11, 68, 72, 3, 3, 70, 3, 81, 3, 2, 3, 68,\n", " 3, 72, 3, 68, 3, 72, 73, 3, 82, 83, 3, 32, 68, 2, 68, 73, 76, 11,\n", " 68, 68, 3, 3, 2, 3, 72, 17, 3, 55, 51, 3, 3, 87, 87, 3, 72, 87,\n", " 3, 3, 15, 89, 82, 3, 73, 81, 3, 85, 3, 74, 17, 75, 85, 3, 80, 79,\n", " 3, 85, 3, 3, 15, 88, 3, 12, 77, 83, 87, 86, 74, 72, 91, 3, 3, 79,\n", " 71, 29, 72, 3, 72, 3, 68, 81, 73, 3, 71, 74, 86, 3, 87, 3, 79, 17,\n", " 85, 76, 3, 12, 85, 72, 72, 5, 3, 3, 90, 3, 17, 3, 3, 72, 86, 80,\n", " 3, 82, 86, 68, 72, 3, 69, 71, 86, 85, 85, 3, 77, 10, 71, 76, 3, 3,\n", " 3, 85, 87, 5, 76, 81, 3, 3, 90, 68, 71, 81, 85, 80, 3, 79, 68, 17,\n", " 79, 3, 86, 3, 70, 73, 1, 3, 3, 3, 3, 87, 86, 3, 86, 68, 3, 87,\n", " 72, 3, 2, 86, 72, 87, 79, 71, 81, 15, 86, 85, 87, 78, 86, 74, 3, 66,\n", " 79, 82, 76, 80, 82, 2, 79, 70, 74, 51, 68, 66, 3, 3, 79, 3, 71, 2,\n", " 86, 85, 3, 68, 3, 5, 39, 71, 72, 81, 87, 3, 3, 3, 3, 3, 3, 72,\n", " 3, 72, 88, 81, 3, 85, 82, 68, 39, 3, 72, 68, 2, 3, 68, 70, 80, 69,\n", " 72, 2, 49, 82, 82, 81, 90, 3, 76, 2, 66, 72, 3, 3, 3, 3, 88, 72,\n", " 3, 19, 3, 72, 81, 3, 76, 78, 3, 3, 76, 76, 89, 81, 80, 87, 3, 81,\n", " 68, 81, 79, 2, 11, 71, 3, 81, 72, 49, 87, 72, 76, 5, 73, 68, 3, 3,\n", " 79, 72, 71, 21, 3, 85, 17, 3, 3, 3, 79, 76, 3, 3, 68, 82, 3, 87,\n", " 81, 3, 87, 72, 3, 87, 21, 44, 3, 87, 68, 81, 3, 68, 82, 87, 3, 70,\n", " 3, 12, 72, 87, 85, 2, 3, 3, 75, 82, 50, 3, 79, 85, 68, 64, 3, 62,\n", " 8, 11, 79, 72, 3, 76, 85, 15, 33, 32, 68, 3, 5, 3, 76, 21, 88, 79,\n", " 49, 68, 3, 88, 51, 3, 3, 80, 68, 79, 3, 3, 2, 32, 3, 12, 2, 68,\n", " 3, 81, 17, 87, 71, 3, 3, 76, 76, 92, 3, 88, 3, 3, 49, 80, 3, 85,\n", " 3, 3, 19, 79, 49, 87, 3, 3, 76, 80, 3, 68, 2, 71, 11, 3, 13, 3,\n", " 72, 29, 76, 86, 82, 3, 29, 32, 3, 10, 81, 3, 3, 3, 87, 85, 17, 85,\n", " 10, 11, 85, 78, 3, 17, 78, 70, 21, 3, 83, 3, 3, 90, 3, 85, 3, 82,\n", " 71, 88, 3, 63, 80, 89, 54, 69, 11, 15, 3, 81, 3, 69, 72, 68, 3, 88,\n", " 42, 68, 3, 38, 19, 12, 41, 76, 87, 11, 3, 87, 3, 76, 79, 72, 82, 3,\n", " 70, 2, 2, 67, 3, 15, 80, 87, 2, 54, 2, 11, 11, 3, 70, 87, 87, 70,\n", " 75, 12, 11, 82, 87, 2, 81, 70, 70, 21, 82, 81, 3, 72, 86, 72, 2, 47,\n", " 87, 3, 89, 78, 29, 85, 68, 12, 86, 82, 49, 3, 3, 81, 3, 73, 76, 83,\n", " 3, 68, 79, 3, 72, 3, 3, 3, 72, 3, 83, 63, 3, 86, 82, 73, 71, 72,\n", " 54, 72, 3, 17, 92, 73, 3, 86, 3, 82, 15, 3, 72, 74, 3, 72, 3, 17,\n", " 3, 72, 81, 75, 72, 3, 3, 3, 87, 37, 75, 73, 3, 5, 86, 2, 5, 81,\n", " 68, 3, 15, 2, 85, 85, 3, 3, 3, 72, 3, 80, 72, 68, 3, 85, 11, 83,\n", " 62, 72, 3, 81, 86, 2, 75, 79, 3, 80], device='cuda:0')}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 构建数据读取器\n", "train_loader = DataLoader(tokenized['train'], batch_size=batch_size, shuffle=True)\n", "test_loader = DataLoader(tokenized['test'], batch_size=batch_size, shuffle=True)\n", "# 获取一个批量的数据\n", "next(iter(test_loader))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "WVXSXA6vsiSS" }, "outputs": [], "source": [ "class CharMLP(nn.Module):\n", "\n", " def __init__(self, vs):\n", " '''\n", " 根据文本背景预测下一个字母是什么\n", " 参数\n", " ----\n", " vs :int,字典大小\n", " '''\n", " super().__init__()\n", " # 文字嵌入层\n", " self.embedding = nn.Embedding(vs, 30)\n", " self.hidden1 = nn.Linear(300, 200)\n", " self.hidden2 = nn.Linear(200, 100)\n", " self.out = nn.Linear(100, vs)\n", "\n", " def forward(self, x):\n", " '''\n", " 向前传播\n", " 参数\n", " ----\n", " x :torch.LongTensor,背景文本,其中的元素表示相应位置的字母在字典中的位置\n", " 返回\n", " ----\n", " h :torch.FloatTensor,预测结果的logits\n", " '''\n", " # 因为背景文本的长度(context_length)等于10,\n", " # 所以x的形状是(B, 10),B表示批量数据的大小\n", " B = x.shape[0] # (B, 10)\n", " emb = self.embedding(x) # (B, 10, 30)\n", " h = emb.view(B, -1) # (B, 300)\n", " h = F.relu(self.hidden1(h)) # (B, 200)\n", " h = F.relu(self.hidden2(h)) # (B, 100)\n", " h = self.out(h) # (B, vs)\n", " return h\n", "\n", "model = CharMLP(len(tok.char2ind)).to(device)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "82S3F9IXsiSS" }, "outputs": [], "source": [ "@torch.no_grad()\n", "def generate(model, context, max_new_tokens=300):\n", " '''\n", " 利用模型生成文本(反复使用模型进行预测)\n", " 参数\n", " ----\n", " model :CharMLP,生成文本的模型\n", " context :torch.LongTensor,背景文本,形状为(1, 10)\n", " max_new_tokens :int,生成文本的最大长度\n", " 返回\n", " ----\n", " out :list[int],生成的文本\n", " '''\n", " out = []\n", " # 将模型切换至评估模式\n", " model.eval()\n", " for _ in range(max_new_tokens):\n", " logits = model(context)\n", " probs = F.softmax(logits, dim=-1)\n", " # 根据模型预测的概率,得到最终的预测结果(下一个字母)\n", " # 这一步运算有一定随机性\n", " ix = torch.multinomial(probs, num_samples=1)\n", " # 利用模型的预测结果更新文本背景\n", " context = torch.cat((context[:, 1:], ix), dim=1)\n", " out.append(ix.item())\n", " if ix.item() == tok.end_ind:\n", " break\n", " # 将模型切换至训练模式\n", " model.train()\n", " return out" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4zSHFccTsiST", "outputId": "e0a6ca3f-fe74-49c6-caf1-0080eef4b158" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ")YN'E.ne'!XOzAYD\n", "F{tvö290&^#>P8(MZzJP\n" ] } ], "source": [ "# 使用模型来生成文本\n", "context = torch.zeros((1, 10), dtype=torch.long, device=device)\n", "print(''.join(tok.decode(generate(model, context))))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bo3UzjmXsiST", "outputId": "37337a6e-a3c0-40aa-84cb-6b88ec257da7" }, "outputs": [ { "data": { "text/plain": [ "{'train': 4.5956830978393555, 'test': 4.594418525695801}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def estimate_loss(model):\n", " re = {}\n", " # 将模型切换至评估模式\n", " model.eval()\n", " re['train'] = _loss(model, train_loader)\n", " re['test'] = _loss(model, test_loader)\n", " # 将模型切换至训练模式\n", " model.train()\n", " return re\n", "\n", "@torch.no_grad()\n", "def _loss(model, data_loader):\n", " \"\"\"\n", " 计算模型在不同数据集下面的评估指标\n", " \"\"\"\n", " loss = []\n", " data_iter= iter(data_loader)\n", " # 随机使用多个批量数据来预估模型效果\n", " for k in range(eval_iters):\n", " data = next(data_iter, None)\n", " if data is None:\n", " data_iter = iter(data_loader)\n", " data = next(data_iter, None)\n", " inputs, labels = data['inputs'], data['labels']\n", " logits = model(inputs)\n", " loss.append(F.cross_entropy(logits, labels).item())\n", " return torch.tensor(loss).mean().item()\n", "\n", "estimate_loss(model)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "18nkKzWLsiST" }, "outputs": [], "source": [ "def train_mlp(model, optimizer, data_loader, epochs=10):\n", " # 记录模型在训练集上的模型损失\n", " lossi = []\n", " for epoch in range(epochs):\n", " for i, data in enumerate(data_loader, 0):\n", " inputs, labels = data['inputs'], data['labels']\n", " optimizer.zero_grad()\n", " logits = model(inputs)\n", " loss = F.cross_entropy(logits, labels)\n", " lossi.append(loss.item())\n", " loss.backward()\n", " optimizer.step()\n", " # 评估模型,并输出结果\n", " stats = estimate_loss(model)\n", " train_loss = f'train loss {stats[\"train\"]:.4f}'\n", " test_loss = f'test loss {stats[\"test\"]:.4f}'\n", " print(f'epoch {epoch:>2}: {train_loss}, {test_loss}')\n", " return lossi" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "n_af5vmnsiST", "outputId": "80a2167d-ff62-4261-b7c8-4e63f2b7287f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0: train loss 1.3726, test loss 1.5097\n", "epoch 1: train loss 1.2598, test loss 1.4965\n", "epoch 2: train loss 1.1934, test loss 1.4247\n", "epoch 3: train loss 1.1630, test loss 1.4014\n", "epoch 4: train loss 1.1505, test loss 1.3658\n", "epoch 5: train loss 1.1539, test loss 1.3594\n", "epoch 6: train loss 1.0862, test loss 1.3975\n", "epoch 7: train loss 1.0872, test loss 1.3718\n", "epoch 8: train loss 1.0707, test loss 1.3832\n", "epoch 9: train loss 1.0845, test loss 1.3286\n" ] } ], "source": [ "l = train_mlp(model, optim.Adam(model.parameters(), lr=learning_rate), train_loader)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 448 }, "id": "a4tA30XfiMFv", "outputId": "d8065926-15f2-4f0e-e3dd-2a5e3ef16768" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABL2UlEQVR4nO3dd3zTdf4H8FdGk86ki+5BoezSsqFspCw5BM+BiIKKKByoOE/0zvnTese5zlPQQ0VFxAl4yF5FoOxWWkahUGiBDtrSpDMd+f7+SPNt0qYL2n5p83o+HnnQfPNN+skXaF79jPdHJgiCACIiIiKJyKVuABEREdk3hhEiIiKSFMMIERERSYphhIiIiCTFMEJERESSYhghIiIiSTGMEBERkaQYRoiIiEhSSqkb0BRGoxFXr16Fm5sbZDKZ1M0hIiKiJhAEAYWFhQgICIBcXn//R7sII1evXkVwcLDUzSAiIqIbkJGRgaCgoHofbxdhxM3NDYDpzWg0GolbQ0RERE2h1+sRHBwsfo7Xp12EEfPQjEajYRghIiJqZxqbYsEJrERERCQphhEiIiKSFMMIERERSYphhIiIiCTFMEJERESSYhghIiIiSTGMEBERkaQYRoiIiEhSDCNEREQkKYYRIiIikhTDCBEREUmKYYSIiIgk1S42ymstn+9LQ0Z+Ce4bEoyeftyAj4iISAp23TPy24mrWHXgItLzSqRuChERkd2y6zAir97S2ChI3BAiIiI71qwwsnz5ckRGRkKj0UCj0SA6OhqbN2+u9/xVq1ZBJpNZ3RwdHW+60S3FHEYEgWmEiIhIKs2aMxIUFIR33nkH3bp1gyAI+OqrrzB9+nQkJCSgT58+Np+j0WiQkpIi3pdVB4Bbgbkp7BkhIiKSTrPCyLRp06zuv/XWW1i+fDkOHjxYbxiRyWTw8/O78Ra2InPPSBV7RoiIiCRzw3NGqqqqsHbtWhQXFyM6Orre84qKihAaGorg4GBMnz4dJ0+ebPS1DQYD9Hq91a01KOQcpiEiIpJas8NIUlISXF1doVarsWDBAqxbtw69e/e2eW6PHj3wxRdfYMOGDVi9ejWMRiOGDx+Oy5cvN/g9YmNjodVqxVtwcHBzm9kkNcM0DCNERERSkQnN7BYoLy9Heno6dDodfvrpJ6xcuRJxcXH1BhJLFRUV6NWrF2bNmoU333yz3vMMBgMMBoN4X6/XIzg4GDqdDhpNy9UDmfvFYcSdvYZ374nCXQODWux1iYiIyPT5rdVqG/38bnbRM5VKhfDwcADAwIEDceTIEXz44Yf49NNPG32ug4MD+vfvj9TU1AbPU6vVUKvVzW1as8nZM0JERCS5m64zYjQarXoxGlJVVYWkpCT4+/vf7LdtETV1RhhGiIiIpNKsnpGlS5diypQpCAkJQWFhIdasWYM9e/Zg69atAIA5c+YgMDAQsbGxAIA33ngDw4YNQ3h4OAoKCrBs2TJcunQJjz76aMu/kxsgY9EzIiIiyTUrjOTk5GDOnDnIzMyEVqtFZGQktm7digkTJgAA0tPTIZfXdLZcv34d8+fPR1ZWFjw8PDBw4EAcOHCgSfNL2oKiuqnsGSEiIpJOsyewSqGpE2Caa+HqY9icnIU3Z0TgwWGhLfa6RERE1PTPb+5NA9YZISIikpJdhxGxzggnjRAREUnGrsNITTl4iRtCRERkx+w6jLAcPBERkfTsOoywHDwREZH07DqMyFlnhIiISHJ2HkZMf7JnhIiISDp2Hkaqe0bYNUJERCQZuw4jLAdPREQkPbsOIywHT0REJD27DiOcwEpERCQ9hhGwzggREZGU7DqMsM4IERGR9Ow6jIjl4I0SN4SIiMiO2XUYYTl4IiIi6dl1GOEwDRERkfTsOoxwNQ0REZH07DyMmP5kzwgREZF07DyMmOeMSNwQIiIiO2bXYUQmrqZhGiEiIpKKXYcRhThnhGGEiIhIKnYdRmrmjEjbDiIiIntm32GEdUaIiIgkZ9dhhHVGiIiIpGfXYYTl4ImIiKRn52HE9CeHaYiIiKRj52GEq2mIiIikxjACrqYhIiKSkp2HEdOf7BkhIiKSjn2HETnLwRMREUnNrsMIy8ETERFJz67DCMvBExERSc+uwwjLwRMREUnPzsMIy8ETERFJza7DCMvBExERSc+uw4hYDp5ZhIiISDL2HUaq3z2HaYiIiKRj32GEq2mIiIgkxzACwMhde4mIiCTDMAL2jBAREUnJzsOI6U9mESIiIunYdRgRy8EzjRAREUnGrsOIQs5hGiIiIqnZdRhhOXgiIiLp2XkYYTl4IiIiqdl1GGE5eCIiIunZdRhhnREiIiLpMYyAPSNERERSsu8wUv3uGUaIiIikY99hROwZkbghREREdoxhBOwZISIikpKdhxHTn8wiRERE0rHrMCKWg+c4DRERkWTsOozIWWeEiIhIcnYdRsx70zCLEBERSceuwwgnsBIREUnPrsMIy8ETERFJr1lhZPny5YiMjIRGo4FGo0F0dDQ2b97c4HN+/PFH9OzZE46Ojujbty82bdp0Uw1uSawzQkREJL1mhZGgoCC88847OHbsGI4ePYrbbrsN06dPx8mTJ22ef+DAAcyaNQvz5s1DQkICZsyYgRkzZiA5OblFGn+zavamYRohIiKSikwQbm6MwtPTE8uWLcO8efPqPDZz5kwUFxdj48aN4rFhw4ahX79+WLFiRZO/h16vh1arhU6ng0ajuZnmWknNKUTMe3vh4eyAhFcmttjrEhERUdM/v294zkhVVRXWrl2L4uJiREdH2zwnPj4eMTExVscmTZqE+Pj4Bl/bYDBAr9db3VqDjMM0REREkmt2GElKSoKrqyvUajUWLFiAdevWoXfv3jbPzcrKgq+vr9UxX19fZGVlNfg9YmNjodVqxVtwcHBzm9kkXE1DREQkvWaHkR49eiAxMRGHDh3CwoULMXfuXJw6dapFG7V06VLodDrxlpGR0aKvb8Zy8ERERNJTNvcJKpUK4eHhAICBAwfiyJEj+PDDD/Hpp5/WOdfPzw/Z2dlWx7Kzs+Hn59fg91Cr1VCr1c1tWrPJWQ6eiIhIcjddZ8RoNMJgMNh8LDo6Gjt37rQ6tn379nrnmLQ11hkhIiKSXrN6RpYuXYopU6YgJCQEhYWFWLNmDfbs2YOtW7cCAObMmYPAwEDExsYCAJ566imMGTMG7777LqZOnYq1a9fi6NGj+Oyzz1r+ndwAloMnIiKSXrPCSE5ODubMmYPMzExotVpERkZi69atmDBhAgAgPT0dcnlNZ8vw4cOxZs0a/O1vf8NLL72Ebt26Yf369YiIiGjZd3GDOIGViIhIejddZ6QttFadkZzCMgx5ayfkMuBC7NQWe10iIiJqgzojHQHLwRMREUmPYaQaS8ITERFJw67DiHkCKwBUMowQERFJwq7DiIOiJoyw1ggREZE07DqMWPeMGCVsCRERkf2y6zCitFiGXFnFnhEiIiIp2HUYUchlYhVWzhkhIiKShl2HEQBQyrk/DRERkZTsPoyY541wzggREZE07D6MmOeNcM4IERGRNBhGFOaeEYYRIiIiKTCMcM4IERGRpOw+jJjnjFRUcc4IERGRFOw+jJjnjLBnhIiISBoMI5wzQkREJCm7DyMKzhkhIiKSlN2HEfME1krOGSEiIpKE3YcRhbnOCHtGiIiIJGH3YcRBwWEaIiIiKdl9GKkpB88wQkREJAW7DyOcM0JERCQtuw8j7BkhIiKSlt2HEQcFi54RERFJye7DCHtGiIiIpGX3YYRzRoiIiKTFMMI6I0RERJKy+zCiYJ0RIiIiSdl9GFFyzggREZGk7D6MKDhnhIiISFJ2H0YcOGeEiIhIUnYfRjhnhIiISFp2H0a4tJeIiEhadh9GWPSMiIhIWnYfRlgOnoiISFp2H0bYM0JERCQtuw8jnDNCREQkLbsPI+wZISIikpbdhxHOGSEiIpKW3YcR9owQERFJy+7DCOeMEBERScvuwwh7RoiIiKRl92FEyTkjREREkmIYYc8IERGRpOw+jJiHaSo4Z4SIiEgSdh9G1ErTJSivZBghIiKSgt2HEUcHBQCgrKJK4pYQERHZJ7sPI+aeEQN7RoiIiCTBMKI09YwwjBAREUnD7sOIo4PpEnCYhoiISBp2H0bYM0JERCQtuw8j5p4RA3tGiIiIJGH3YURtXk3DnhEiIiJJMIxY1BkRBFZhJSIiamt2H0bMdUYAzhshIiKSgt2HEXPPCAAYKhhGiIiI2prdhxGlXIbq7WlgqOQkViIiorZm92FEJpNZlIRnzwgREVFba1YYiY2NxeDBg+Hm5gYfHx/MmDEDKSkpDT5n1apVkMlkVjdHR8ebanRLqykJz54RIiKittasMBIXF4dFixbh4MGD2L59OyoqKjBx4kQUFxc3+DyNRoPMzEzxdunSpZtqdEtj4TMiIiLpKJtz8pYtW6zur1q1Cj4+Pjh27BhGjx5d7/NkMhn8/PxurIVtgCXhiYiIpHNTc0Z0Oh0AwNPTs8HzioqKEBoaiuDgYEyfPh0nT55s8HyDwQC9Xm91a03sGSEiIpLODYcRo9GIJUuWYMSIEYiIiKj3vB49euCLL77Ahg0bsHr1ahiNRgwfPhyXL1+u9zmxsbHQarXiLTg4+Eab2SRiSXjOGSEiImpzNxxGFi1ahOTkZKxdu7bB86KjozFnzhz069cPY8aMwS+//IJOnTrh008/rfc5S5cuhU6nE28ZGRk32swmMfeMcDUNERFR22vWnBGzxYsXY+PGjdi7dy+CgoKa9VwHBwf0798fqamp9Z6jVquhVqtvpGk3RM2eESIiIsk0q2dEEAQsXrwY69atw65duxAWFtbsb1hVVYWkpCT4+/s3+7mthT0jRERE0mlWz8iiRYuwZs0abNiwAW5ubsjKygIAaLVaODk5AQDmzJmDwMBAxMbGAgDeeOMNDBs2DOHh4SgoKMCyZctw6dIlPProoy38Vm6c2DPC1TRERERtrllhZPny5QCAsWPHWh3/8ssv8dBDDwEA0tPTIZfXdLhcv34d8+fPR1ZWFjw8PDBw4EAcOHAAvXv3vrmWt6CaomfsGSEiImprzQojgiA0es6ePXus7r///vt4//33m9WotsZy8ERERNKx+71pAJaDJyIikhLDCGp6RjhMQ0RE1PYYRlDTM8Jy8ERERG2PYQQsB09ERCQlhhFwozwiIiIpMYyAPSNERERSYhgB64wQERFJiWEElnVGOExDRETU1hhGwJ4RIiIiKTGMgHvTEBERSYlhBCx6RkREJCWGEVgM07BnhIiIqM0xjMBiAit7RoiIiNocwwjYM0JERCQlhhGw6BkREZGUGEZQUw6+0iigsoqBhIiIqC0xjKCmZwRg7wgREVFbYxhBzZwRgGGEiIiorTGMAJDLZVApTJeilJNYiYiI2hTDSDXzvJHScoYRIiKitsQwUs1ZpQTAzfKIiIjaGsNINSeVaRIrh2mIiIjaFsNINXMV1hIO0xAREbUphpFqzuaeEYYRIiKiNsUwUs3JvD8Nh2mIiIjaFMNINQ7TEBERSYNhpJozJ7ASERFJgmGkGodpiIiIpMEwUs28tLekvFLilhAREdkXhpFq5jkjpeXcm4aIiKgtMYxU45wRIiIiaTCMVOOcESIiImkwjFRz5JwRIiIiSTCMVHM2zxmp4JwRIiKitsQwUs28mqaMRc+IiIjaFMNINfOckWIO0xAREbUphpFqGicHAICutELilhAREdkXhpFqni4qAMD14nKJW0JERGRfGEaqeTibekaKy6tQXslJrERERG2FYaSaxtEBcpnp64IS9o4QERG1FYaRanK5DNrqeSPXSzhvhIiIqK0wjFjwcK6eN8KeESIiojbDMGLBvXreCIdpiIiI2g7DiAVxRQ2HaYiIiNoMw4gF9+phmnwu7yUiImozDCMWPDhMQ0RE1OYYRiy4O3OYhoiIqK0xjFgwr6ZhzwgREVHbYRixYB6mYc8IERFR22EYseDhwjojREREbY1hxELNMA17RoiIiNoKw4gFy9U0RqMgcWuIiIjsA8OIBfNqGqMA6MvYO0JERNQWGEYsqJRyuKgUADiJlYiIqK0wjNTi7aYGAGTpyiRuCRERkX1gGKmli7cLAOD8tSKJW0JERGQfmhVGYmNjMXjwYLi5ucHHxwczZsxASkpKo8/78ccf0bNnTzg6OqJv377YtGnTDTe4tYX7uAIAUnMYRoiIiNpCs8JIXFwcFi1ahIMHD2L79u2oqKjAxIkTUVxcXO9zDhw4gFmzZmHevHlISEjAjBkzMGPGDCQnJ99041tD106mMMKeESIiorYhEwThhtewXrt2DT4+PoiLi8Po0aNtnjNz5kwUFxdj48aN4rFhw4ahX79+WLFiRZO+j16vh1arhU6ng0ajudHmNsnRi/m4e0U8ArSOOLB0fKt+LyIioo6sqZ/fNzVnRKfTAQA8PT3rPSc+Ph4xMTFWxyZNmoT4+Pib+datxjxMc1VXhmJDpcStISIi6vhuOIwYjUYsWbIEI0aMQERERL3nZWVlwdfX1+qYr68vsrKy6n2OwWCAXq+3urUVd2cVvF1N9UY4VENERNT6bjiMLFq0CMnJyVi7dm1LtgeAaaKsVqsVb8HBwS3+PRpinjfCSaxERESt74bCyOLFi7Fx40bs3r0bQUFBDZ7r5+eH7Oxsq2PZ2dnw8/Or9zlLly6FTqcTbxkZGTfSzBvGFTVERERtp1lhRBAELF68GOvWrcOuXbsQFhbW6HOio6Oxc+dOq2Pbt29HdHR0vc9Rq9XQaDRWt7YU7OkMAMjSs/AZERFRa1M25+RFixZhzZo12LBhA9zc3MR5H1qtFk5OTgCAOXPmIDAwELGxsQCAp556CmPGjMG7776LqVOnYu3atTh69Cg+++yzFn4rLUfrZNowT1/KkvBEREStrVk9I8uXL4dOp8PYsWPh7+8v3r7//nvxnPT0dGRmZor3hw8fjjVr1uCzzz5DVFQUfvrpJ6xfv77BSa9SM4cRHcMIERFRq2tWz0hTSpLs2bOnzrF77rkH99xzT3O+laQYRoiIiNoO96axgWGEiIio7TCM2MAwQkRE1HYYRmzQVIeRsgojDJVVEreGiIioY2MYscFNrYRMZvqavSNERESti2HEBrlcBo0jl/cSERG1BYaRenDeCBERUdtgGKmHh4tps7y7lsdjc1JmI2cTERHRjWIYqcfsISHi12uPtO3eOERERPaEYaQe9wwKQkwvHwBAAYdqiIiIWg3DSD1kMhnmjewCACg2VErcGiIioo6LYaQBbo6mavlcUUNERNR6GEYaIO7eW8YwQkRE1FoYRhpg7hkpqzCivNIocWuIiIg6JoaRBriqazY1LmTvCBERUatgGGmAUiGHi0oBACgs4yRWIiKi1sAw0ggN540QERG1KoaRRpjnjbBnhIiIqHUwjDSCG+YRERG1LoaRRnCYhoiIqHUxjDTCw9m0YV5ecbnELSEiIuqYGEYa4e1qCiO5hQwjRERErYFhpBHermoAQG6RQeKWEBERdUwMI43wdjMP0zCMEBERtQaGkUZ4uVT3jHCYhoiIqFUwjDSCwzRERESti2GkETXDNOUoq6iSuDVEREQdD8NIIzydVVDKZQCAJ75LkLg1REREHQ/DSCOUCjmW3t4LALD9VDZSsgolbhEREVHHwjDSBPNGhmFSH18AwG9JmRK3hoiIqGNhGGmiASEeAIBLecUSt4SIiKhjYRhpohBPZwDApbwSiVtCRETUsTCMNFGIlymMZOQzjBAREbUkhpEmMveM5BWX471tKRAEQeIWERERdQwMI03k5ugAmWmFL/69KxXRsbuQpSuTtlFEREQdAMNIMyy7O0r8Oktfhs/3XZCwNURERB0Dw0gz3D0wCC4qhXi/pJwVWYmIiG4Ww0gzGS2miiSkF7BEPBER0U1iGGmmSqNR/PpUph6xm05L2BoiIqL2j2GkmYI9nK3ufxV/SaKWEBERdQwMI8304X390dPPTbzvplZK2BoiIqL2j2GkmfoGabFlyWgk/H0CAKDQUImS8kqJW0VERNR+MYzcIA8XFdwcTb0iVwtKJW4NERFR+8UwchOCquePZFxnGCEiIrpRDCM3IczbFEYOp+VL3BIiIqL2i2HkJtwRFQAA+PHoZVRWGRs5m4iIiGxhGLkJ43v5wkWlQG6RARdyi/HuthRsSc6SullERETtCtel3gQHhRxdfVxx4rIO728/i83VQeTiO1MlbhkREVH7wZ6RmxTeyRUAxCACAIIg1Hc6ERER1cIwcpO6+rjWOaYvY90RIiKipmIYuUndfd3qHMsvLpegJURERO0Tw8hNGtejE56f1ANfPzIEvho1ACCvyCBxq4iIiNoPhpGbpFTIsWhcOEZ37wR/rRMA4EpBKd7bfhapOYUSt46IiOjWx9U0LcjbVQUAeGptIgDgu8PpOPJyjIQtIiIiuvWxZ6QFebqorO5fK+RwDRERUWMYRlqQl6u6zrHySlZmJSIiagjDSAuK6eVT59jZbM4bISIiagjDSAsaGOqJv07uiQEh7lDIZQCAP320DyXlrDtCRERUn2aHkb1792LatGkICAiATCbD+vXrGzx/z549kMlkdW5ZWR1zD5eFY7vil7+MwBO3hYvHDl3Ix75zuej72lZsTsqUsHVERES3nmaHkeLiYkRFReHjjz9u1vNSUlKQmZkp3nx86g5pdCRLYrrj7oFBAIC4s9fw+DdHUVhWiYXfHhfP0ZVWYFNSJsoqqqRqJhERkeSavbR3ypQpmDJlSrO/kY+PD9zd3Zv9vPbstp4++OnYZRy8kIfi8prAIQgCZDIZFq85jt/P5WL+qDC8PLW3hC0lIiKSTpvNGenXrx/8/f0xYcIE7N+/v62+raT6BmoBAGeyrCexZuSXAgB+P5cLAPg6/lLbNoyIiOgW0upFz/z9/bFixQoMGjQIBoMBK1euxNixY3Ho0CEMGDDA5nMMBgMMhpoaHXq9vrWb2SqCPJxsHk/IuI73tqeI9w1c/ktERHas1cNIjx490KNHD/H+8OHDcf78ebz//vv45ptvbD4nNjYWr7/+ems3rdXJZDKr+65qJYoMlWKFViIiIpJoae+QIUOQmppa7+NLly6FTqcTbxkZGW3Yupb13MTuAIB/3hWJN6b3kbg1REREtx5J9qZJTEyEv79/vY+r1Wqo1XWrmbZHC8Z0xdTIAIR5u+BibrHUzSEiIrrlNDuMFBUVWfVqpKWlITExEZ6enggJCcHSpUtx5coVfP311wCADz74AGFhYejTpw/KysqwcuVK7Nq1C9u2bWu5d3ELUyrkCPN2AQB09nbBu/dEQePkgC/2pSH+Qp543sXcYhy5mI+v4y/hL2O7Ykrf+sMaERFRR9LsMHL06FGMGzdOvP/MM88AAObOnYtVq1YhMzMT6enp4uPl5eV49tlnceXKFTg7OyMyMhI7duyweg17cld17ZHVB61X0Iz91x7x64XfHsfFd6a2ZbOIiIgkIxMEQZC6EY3R6/XQarXQ6XTQaDRSN6dFPPrVEew4nVPv41FBWjw2uiumRrKHhIiI2qemfn5zbxqJOCgavvR/XNbh4931T/IlIiLqKBhGJOLpohK/VsplNs+5lFcMc8fVkYv5SMniDsBERNTxSLKahoCnYrohIb0AMwcHY1Q3b8SdvYbX/3fK6pzi8irkFpWj2FCJe1bEAwDOvTWl0V4VIiKi9oRhRCI+bo7Y9NQo8f7l66Xi1128XXApvwRVRgGD39ph9byE9AIMCfNss3YSERG1Nv6KfYuwLB2/89kxCPV0tnnetpNZbdUkIiKiNsEwcovo0skV78+MwjfzhkAmk2F6v0Cb532xPw2nM9vnXj1ERES2MIzcQu7sH4RR3ToBAJ64LRzRXbzExz68rx+GhnnCKJiGaoiIiDoKhpFblFwuw2CLuSHRXb3Q2ctUyTWvyFDf0+r18e5UrIg732LtIyIiaimcwHoLUylqlvx2clXDy9W0HDivuLxZr5NXZMCyrSkAgFlDQqB1cmi5RhIREd0k9ozcwszzRoZ18YRMJoOXq2nzwAPnc/HaryeRW2RATmEZjMaGi+hm6srEr69YrNohIiK6FbBn5BYW7OmMwy+Nh9bZ1JPhXd0zcja7CGezi7DqwEUAQJi3Cz6+fwB6B9gutXuloCaAXL5eUu95Zul5Jci4XoIR4d4t8C6IiIgaxp6RW5yPxhFqpQKAddVWS2m5xVi85ni9r5FpEUbMwaS80ghdaYXN80cv243ZKw/hePr1G202ERFRk7FnpB3xclHX+9iF3GLoyyrww5EMyGUy+GjU6O2vQZdOrrhqMUxjLq62YPUxHE7Lx/pFI/C39UlwUMjh6KCwmhx7IDUXA0I8Wu8NERERgWGkXTEP01hyUMgggwzlVUb8d+8FfLSrZnM9D2cHBHk4I+mKTjy283Q2Fozpil1nTDsGP/19otXjlsoqjC38DoiIiOriME074umiQqC7k9WxiEAtOnubqrX+cvyK1WPXSyrqBI2LeSV48ecT4v36gggA6MtsD+MQERG1JIaRdkSpkON/T4zE3ufHicdmDgqGv9YUUCwnqtoyvV8AAGBnda9IY64W1AzvVFYZ8cGOszh0Ia+5zSYiImoQh2naGU8XFTxdVPjn3ZFIyy3GPYOCsf+8dUDY/+JtiEu5hrd+O4Xi8irx+B1RAdiQeFW8P72f9f3aUrL1+PfOc4g/nwe5HNifmocPcA4X35naaDurjAIUclmj5xEREbFnpJ26d1Aw/jq5JxRyGZwdFOLxFQ8MRKC7E+4fGoKk1ybhsdFdAACRQVpEBGqtXuPZCT3goqp5rkph/c8hI78U720/i/gLedifWhN4Zn4aj2x9Gerzxv9OYcCb25GRX3JT75GIiOwDw0gH8GRMN4wM98a3jw7F5Ag/8bhcLsOzE7vjzRkR+HzuYPi4qRHqZZpfEqB1RLCnE3Y+OxYrHhiAtNjb8cLkHk36fofS8nH/fw9aHbt8vQSxm04jt8iAL/anQVdagZ+PX265N0lERB0Wh2k6gEB3J6x+dKjNx9RKBR4cFire/3nhcKz8PQ3Du3pBJpPBT+uIyVp/AEAnN9tLh1/5U2+8sfGU1bHz14px/loRunZyxZ6UHDz05REAwPdHM8RzHKp7Wg6cz0W4jyt83Bxv/E0SEVGHxTBiZ7xd1XhxSk+bj/X210AmAwSL6vJuaiUeHtEZmbpS/Pf3NKvz715+APNGhuFf286KxwpKalbgXCs0YNeZbDyy6igC3Z2w/8XbWvbNEBFRh8AwQqJuvm7Y9OQoKOQyuDs7YNmWFDw6qgtkMhlentq7Thi5XlJhFURqy9KViRNkG1vp0xSXr5fgxZ+TcO/gYNwRFXDTr0dERLcGzhkhK738Neju6wYfN0csuycKPfzcxMcWjesKAHh+Ug/cOyioznN7+2uwZv5Q/OOuvgCAHaezG1ytU1hWgYIS2zsQl5ZX4YWf/sDO09nisXXHr2Bfai6e/C4ByQ3URyEiovaFYYSa7OmY7tj4xEgsGNMVQ8K8xOODQj2w8YmR+PbRoRje1Ru9/E0b8VXW2k24rKJmmbHRKCA6dheGxe60Om728e5U/HD0MuZ9dVQ8lpZXLH597NJ18TW3n8pGSXnlTb03QRAQd/YarhUaGj+ZiIhaFIdpqMmUCrm4PLi3f83Ov/cPDbFaNly7SqxZXnE5HOQyXMwrQbCnE4oMpgBx/loRQjyd4ebogJzCMjjI5Th5tabnI/mKDu9tPyuWsAeAczmFqKgy4q3fTuObg5cwc1AwZg0NwebkTDwd0x2O1cudBUHAv7alwE/jiAejO9dpU2JGAR7/5ij8tU5IzChATC8frJw7+MYvEhERNRvDCN2Q7r6uGN7VCyqlHNP7BVo95uWqxtt39sVr/zuJ8sqa/W3yi8rx5sZTOHwxHw8N7ywen/rvfXBTK/HVvCF46IvDcHRQWBVMm/fVEWTrrXssVh9MR5VRwHeHTat3vj+agZ+OX0aVUUCJoQqPjAyDrrQCTg4KfLz7PAAgW2/AkphuUCrkOHYpH39ffxKnMvXiYwCw43QOxizbjckRflg6pVfLXTAiIqqXTBAEofHTpKXX66HVaqHT6aDRaBp/At0SBEHAiz8nict9v3x4MB6uXgLcmpxVCriolcgvLsezE7vjn1tSxMfeuzcKfx4QhG4vb0JFVcP/9L98eDDG9fBp0bYJggCZjJVpicg+NPXzm3NGqNXIZDL84+5IjAz3BgDsbuKeOLb4aRwxvKsX/DSN1yopKa/CtUIDqoyCVRABgL1nr6G0vKrRIAIAD395BMcu5d9wm2vTl1Vg9LLdWLI2ocVek4ioI2AYoVbn5aoCAHwdf6nec6ZYVI41s5x70t3PDWvmD8N/5wy6oTaYC7rtS83F5/suNPl5xy8V3ND3MxMEAb+fu4bdZ3KwPuEKMvJLsT7xKlqyQ7Kyyoidp29+Ei8RkVQYRqjVNWWo4/ExXcWvn5vYHctnD8BT47uJx3yrw4SPxnaVWLPBnT3w04JovHVnBGYNCal5/dFd4O2qRm5RuVVtFGeLvXlsaWxERRAEHL2YbzMI/Hg0AzM/PYgHPz+MeV8dwR8ZNZNy9WXW58edvYaVv1+A0WIFUl6RAVm6+vcAMnvl15OY99VRLNua0ui5bS01pwhpucWNn0hEdo0TWKnVzegfiKmR/njw80M4eCEf0V28cPhiPqosPng7eznjn3dHYveZHDwyMgzOKiXizl4THzeHEC8XlXjsj1cnYlNSJpb+kiQeG9vDB4M6e2JQZ08YKqvw3eF0AKZeln/P6odHVh1BWYURUyL88MF9/VBWYURukQEHzufhyvVSrIg7b9X23KKaOigLVx/DlYJSrH50KJ5Yk4Aefm4YEe6NuV8cRmSQFj8vHI5DF/Lh5qiEg0KO5386IT7XKABbkjPF+9n6MmidHACYejbmfnEYgKkHZ3q/QJy6qsddyw/AQSFD/NLxcFGb/qsajQJkMojzTt7ZfAZrDpne45f7L+LVaX2a/fdzIzYkXsHBC3l47Y4+UCttB7piQyVi3osDAJx7a4q4PQARUW0MI9QmHBRyfP3IUBy8kIdBnT3goJDjrd9OY9WBiwAArZMD7h0UjHsHBYvPsZwf4lv9tVIhx7anR6OySoDWyQGzhoTgTKYe6xOv4rmJ3TF7aM0+PGqlAiseGIjDafmY0NsXSoUce54bh8SM67itpy9USjnUSgW0Tg7o2skVx9Ov1wkjOdW7E2fry7A5OQsA8K+tKYg7ew1x1fNPAODEZR3uXn4Af1zWQSYDRnXrVOcaFJfX1FOZ+P5edPNxxbyRYejSyVU8/v2RDEzvF4i3N51GaUUVSiuAPq9uRW9/DSICNVifcBUDQz3w9z/1hqGyyqq9IZ7OzftLuQlPrU0EAIR5u+Cx0V1tnmPZI5KlK0OwpzNKy6vg1EhvFBHZH4YRajMqpRyju9d8SFdU1Sz7tbXCxDKMmHsRAKC7r5vVea9Pj8Cr0/pALq/7GpMj/Kx2MrbcGLC2fkHu8NM4IktfMzRyLqcIxy5dF0MJYD335Xj6dfHrPy6bhmEEwTRRtjHncorw4i9JcLH4cD54IQ9pucU4eCHP6txTmXpxGXL8hTzc/u/f67xeen4JNiReqbPUuinOZhcixNNZrM9SH11pBZ7/8Q/x/tGL1/HYaFPF3FOZegwIcRf/Li9aFKm7WlCKgxfy8MLPJ/CPuyKtQmdrWPn7BRQZKrEkpnurfh8iahkMIySZXv4NL9PWONX886xvKMDMVhBpLrlchi1LRmHrySz89WfT0E/SFR3uWn6g3raevKq/6e9r2WNiFIBx/9pzw6/11NpETI7wa/R6Hb2YjysFpfByUSNTV4rnfzqBx0d3wV8n96xzLQ2VVcgrKsfGE1fx9qYzVo+dztJDEAS8sfEkvjucgUdGhOGVab0BAGnXasJI3Nlr+GSPqRfnhZ9OIFtXhnsHB4s9Xi1BEAQs25oCV0eluIpqcoQfevrdXDmAIkMlvjpwEbf39UeYt0tLNJWIamGdEZJMZZURn+69gJHh3ogKdrd5zpsbT+HE5QKsfnRoox+wLelKQSlG/mMXmvq/4717o1BQUoHR3b0R897eZn8/lVKOBWO64t87zwEAlHIZHBRylFaXyv987iCr0viWOns542JeiXh/05Oj0Dug/v8nlVVGhL+82eZjbmolZg8LtdrZ+fkf/8BPxy/Xey3u7B+IdQlXxPu/vzAOwZ7OWLj6mDi0Zcvgzh74ccHweh9vrlNX9XV6jD68r1+dnqLKKiOUzZi/8tqvJ7HqwEV0clPjyMsxLdJWInvR1M9v9oyQZJQKORaNC2/wnL//qXcbtcZaoLsT/vvgIHy+Lw3xtYZManNRKfCnyAColE37gBvVzRtBHs7i5FoAmNrXH0/eFo6oIC30ZRXoH+yBr+Iv4sv9FwEA0V296nk1wFll/d/4TJa+wTBiHu6xpdBQiU/3nsfUvv7w1ajhqFLgx2OXG3w/lkEEMJXvzyk0NBhEAODIxZr9hRobHmqKlOy67+tUph7T+wVCEARUVAn4z65z+HTvBbx2Rx8kpF/H85N6isu+AeCjnedQYRTwdEw3cbjJvA3BrbRvUaauFIvXJOCh4Z0xjTtYUwfAMEJUj5jevhgc5omo17eJx468HINObmoUGSrxdfxFFJZVYnIfvwaDyITevigoKceRi9fxzITueLJ6yfL/zYjAtI/2Ib+4HEtv7wmlQo7xvXzF51mW0ndWKTG1rz9+S8qs8/oB7o5WAeNMVmGD7+twWsOF3AQBmPaffQ2eAwBujkrcNSAI6xOvoLJKgEIug660Aqcy9TjRxF2VX1qXhO+PZOCHx4dhYKhno+dfLSjF6Uw9buvpU2eeUfKVumHk4IV87EnJwUO1Kv+aV2AVl1fh4/sHADDN/3l3u2nZ9/R+AehaPbHYctTqgZWH8NUjQ6y2K2iM0Shg68ksDAj1QEl5FT7bex4LxnRFqNeND/m8/uspHLt0HccuXRfDyJksPdydVPDTWg996csqUFpe1aJDYh1ZRn4Jdp3JwczBwS0SkqWUX1wOdyeHFhnGbm0MI0QN0Do5wM1RicLquiDe1QXcXNVK/GWs7V6dEeFe2J+aBw9nB6yZPwxh3i6oMgrYkpyF2/vWTJ5VyGX43xMjUV5ptLnC5Pa+/vj2ULo4T+Hde6PwyMgw3LX8AABgYKgHPJxVeO2O3thxuqa6bWpOUZ3XKimvxH/3pmFalD8ONRJGmuqbeUPRL9hdHNL5/kgGXv31JHaczkFekakXYUlMN2Try7An5Royq2umhHg6Iz3fNKxkXpa843ROo2Fk79lrmFO9BPq/cwZhUKgHnlybgMkRfpg9NBQnLhfUec4fGQV1goil305kYvE4PXr5a/D572ni8fHvxuGLhwbhtp6+sByd2peaiyvXSxHi1fSVS98dScfL65LRy18DJwc5jqcX4LvDGRjf0wePje6CoV3q9noJggBDpbHeD0PLycEAcORiPu5ZEY8evm7Y+vRoq8fmfnEYZzILsePZMfVuYtkcGfklWJ9wBXOiO0PrXDOx/D+7zqHSKFhNGq6sMmLryWyUVlThzv6BzQpxbeV0ph7nr5n+z4zu3gn3rzyIjPxSXNWVtuv9qZIu6zDtP/swvV8APryvv9TNaRQX/hM1op/FfJam7Cuz7O4o3D0wCKsfHYpe/ho4Opj2yrlrYFCd0KGQy+pd6joi3Bs/L4zGzwtN8yocHRQYGOohFnN76faeWDl3EII8nPHpgwPhX/0b8a4zOZi98qDVaqW/rUvG+zvOYuHq4zhy0TqMvHR7T/x7Vn+M7t4J3z46tM4HRkwvH+x8dgzGWKyE+mT2APG6ODoo4OigQGSQaefm05l65FQPafQJ0CL2z5H4/rFoqJVy3BEVgL0vjMPkPtYVd/WlFQ1e05NXdWIQAYDdKTn4aFcqfj+Xi5fXJePOT/aLwz7NNeXD3xG7+TQ2J1v3Oj2y6igOpObWKTx34HyuzdfZdjILU//9O1Jq9Ux9fcC0+up0ph4nLtf0GO08k4P3d5yFLa9sOInI17fh5XVJuFpQWufxEotJzzM/jcc9K+IBACnZhUi+osP0/+zDD0cykJhRgIT0ApRWVGHX6ez6LkGzPP7NMby7/SyWrqupo3Mprxj/2nYWH+w4h5zCmuu1IfEqFq05jud+/AM7Gvn+OYVlWPn7BZuhsjVN+fB3LF6TgMVrEvDyumRk5Juu9/paw4/tTdxZ0y8oGxKv4ujFltvWorWwZ4SoEW/f2Rfzvz6KOdGdm3R+gLsT/nVPVIt8b1u9Bf83IwJPT+gGH7eabvdJffzQxdsFE943TZ7dn5qH1QcvIcjDGdtPZeGX6h+sKdl1h3DMdULuqO7u/+HxaGTry6BWyjGqWydxCOqT2QPQ59WtAGAzQPULdrfqRQJqap+EeDnj2N8nwLn6N/3bI/2x5WTNnJL6Ks3uO5eLIkMFrpdYhxVzj4pZQnpBned6u6qsitaplXIYKo14bHQX9PJ3w9Pf1yxR/jTOtEWAh7OD1fe6f+WhOq/74i9JcFIp6kyMfeybYwCApb+cwC9/GQHANLcj43rN5OJKo/UsYPNu0UajgM/3pWFAqAcig7T45qApwHx7KB3fHkqHTGbqKZszLBQ/H78s9iwBqNPT9dyPf+BMViH+uHzC6vjpeobvDqTmIsDdCZ1trBS6WlCKD3acxYIxXeHtpoabWikOCW5Kqvn725NSs5Q9S1eGtGvFeHl9slUv3YVrtivxVhkFfB1/Ea//75R4bHBnD3z58BC4qlv3I8pY6+/jf39cFb/Os/i305Ckyzr8nnoNE3v7ItzHrfEntBHLVXpvbzqNnxcOr/PLlCAIEISWWY14sxhGiBoR7OmMLUtGN35iG1HIZVZBxCygVhe85Q/32mJ6+eLugYHivAhLA0M9bD7HxeKDoZtP3efJZDJ8MnsAHvy8pgcj2LOmTZYfLDG9fKyGazItwsivf1zF1uQsvDkjAg98XjcM1MdNrcSSCd3x5kbT+x4R7o0NiTUfLodeGg9DpRGeLio4KOSQy2Q4lanH4bR8JKQXwEWlwI8LhkPjpISLSokXf0kSP5xGhHvh1FW9GFSeWptYbz0Xc8DYkpyFBauPNdjmtNxifHXgIqqMAt7adBoAbM4/EgTTkNJvJ+rOGaqtvjlDxy/V7Tk6dVUvBq6L70wVj5eWVyH+Qi4+2pWKhPQC/HD0MmQy4LHRXaye//HuVGw7lQ2lxYfZztM5uFpQWme4MEtXitjNp3ExtxivTuuD0ooqeLuo8fam0+LO3mZHLl7HuoQreHBYKG6EeYLvnOjQOn9PgiDg5FU9unZyRXED+zmZg2NqThH+tTUFMwcHY1xP660tig2VeHjVEeQWGfDPLSmYNSQEf5vay+r/SnO8ufEUTl3V4+t5Q8SKxVtPZkEQBEyOsF0fqT45+poJ18fTC/Dp3guYEx1qNeH94VVHkJFfgt+eHCX5/BiGEaIOoqEfgC9M7oHV8ZdwVWcqQ//2nRHwuYEJjbufG4v84nIEedieMzGqWyfEPT8WY5btQYinc52VPmbOKiW2PT0aZ7MLccd/9uNUph5LfzmBucM748nvTLsaN2V+wb2DgvDTsctY8cBATKwe+jGHkbKKKkQFafHHZR3cnR3g7qyyeu70foGY3i8Qhsoq7E/NRWcvF6tquB/N6o87+wfgXHYRHhrRGf+38bTYYwEAm5MyMaWvP9Jyi8UVN4BpWfjhtHw81cDuzL4atRhaXv31pNVj5onL0/sFYNeZHKuepubwdlVBLpNh5dxB+PMnB3AmqxBnsvRWdVcSMwrEryurjNidcg3fHLwEtVKO7aesh1UEoaYHyczWfkgfVi9Pr+14egGSqic2bz1pCjCd3NRWQdTS4bR8XC8ux4PDQuHhorJ5Tn3e23ZWnOBrDiNVRgF7z13Dleul+Nv6ZMT08sHTExouirch8YpYbXjLySxse3o09qfmYn9qHgRBwPlrRcgtqvnQ/+5wOr47nI4Xp/TEgjG2KxPXRxBMvWOAaW7SuB4+yCksw+PVPW7myfOW5//15xMQBOCfd0dCJpMht8iAHL0BvQM0VsNlgGnriMvXS+CnccR3hzOw6uHBYo/W8fTrGN7Vu1ntbWkMI0Qd3P1DQ/CXseEY3tUbm5Iy8fjoLvBybXjDwfqEebs0Wvgr1MsFu58bCxd1w79pOTooEOpZ81rfHc7Ad4drfkP+1aLL3JbnJnbH4tu6IfbPkTaDS99ALf51TxTe234W/UNs9/YApoJ6t/X0tfnYbT19xcfuGRRkFUYWf5eA+8/nWR0zu/fT+Abb/tOC4Rj1z931Pt7Tzw0vTO6Jh0eEYcbH++s8HurljEsWtWX+PCAQj4/uigWrjyEttxivTeuNWUNDIAim6zyxjy82JWVh3fErWHp7TRgxWhSPySk0YP7XtmvZtISkWiusKo1CvUEEqBky2ZSUiU1PjmrWUEKOxTLsjPwSrD50Cdf0BnG4EjBNmr5/aIitp4vMQcRszaF0cQuLhryz+YzNMGI0CrhSUIogDydUGQUs+T4RfhpH/O1Pva2GVQzV9YX2p9bMT0rMKMCE3jX/TjPyS/HDUdOye11pBf51bxQeWHkIZ7IK8eF9/cSeEbnMVEwRAFYfrBneNA/pAk0fkmpNDCNEHcikPr7YejIbr9/RB+WVRsweFiL2TvQLdreajNuamlqpVOOkhJtaiUJD47/99w3U4va+/vjHFlMV2LsHmkrK1w4iO54Zg22nsvDw8DA4qRQttnlgZJA74p4fC08XFR5ZdQRHLl63GUSawruRMGgeFgx0d8KjI8Owcl/NSp+oYHdsWDQCnV/8TTz23r39AACbnxqFbH1ZnWXDQ8O8sCkpy2r+yo5T2fjDomfkio2JsgDgoJChoqr1a2P+KdIfG20MQ53JKkRusaHO0GTs5tNQKeR4cnw3bD2ZhYgALTp7u+DCtSKrTTaf+SGx3snNKVmmoaTbevpApZBbzWOyZF4hZyuIRAZpERmktfqgB0yVi2sXavzXthR8suc8VjwwAB7OKvH9Pj6mq1UvlXmC8u9na8JIQvp1eLmq8Ld1yXhzRoRV3Zttp7Jxz/J4cU7Y39Yli+GtX7A7jtuYU2XpwrViZOnK4OWqkmxDS66mIepAPpjZHxufGIk50aGYP7pLvcMktwqZTIa+1atwzCKDtFZzEMx+XTwCC8d2xcf3D8AXDw2qU0/DLNzHFX8ZG94qG/KFernAzdEBD9SayzBriO29doaGeeLtO/vWOe6kUmB4dSG7h4Z3hqqBD4DaC7h6+VlPkrQMY44OCpv1S8zd++YPsPjzeXj066NWBe3mWMz1MXtmQnecfmMy/nl3JL58eLBVO3c+OwYPDe9cb7vNc48s915yUNTfu/Hm9Agsnz3A5mMPrDyE89eK8Pm+NBSWVSAlqxCfxl3AR7tSMeofu7F4TQKerB4Wq72Uu6FVVuZg66tR4917bU8693B2wDt/jqz3NTq5qvHC5J64a0CQ1XHLnivANKxi3hLh/e3nrFZWvbP5jFWv1DM//IHkKzqrSceJGQWY9dlBnMrU46EvD9cpXmg5Ob3QUAld9Qq1yCD3ettu9v6OsxgWu1Nc4iwFhhGiDsRJpUBEoLZJS5BvFbWLcQ0M9cCa+cOw7i/Dsff5cXBTK/HAsBDxPU2N9K93WKWtBFvskDy5jx9i/xyJAy/eBudaAeieQcG4f2gIPryvX53X+M/9A/Dzwmi8dkcf7HpuDOaPCoNchjrnDu5svaLKvOrJ/GE/f5T1pFJbLMOIIAg2l3qatx4we3RkGJ4c3w1KhRz3DgrGuB4+6O5XM6emaydXvHZHH6vgODe6JqQ9Nb4bDr00HgeWjkdML18M6+KJwy/FYHT3Tpg3MgwqpRyODnLMHxWGj2b1h4eLClP6+uP43yfATa1EiKczelYHr7PZRRj/bhze3HgKfV/bhkkf1AwxmDe2PHFZh4KScquVRpYm9/Grt6JzqJcLXNRKeNqYmxLs6YxgT2f8KdL2BFJPFxU0jg54994obHpylHj8Qq0Pdsvw4KNRW22y+fPxulWO//TRPlyyqCdzNrsIhur5RIVllUiqZwm0ZY4f0tkTGotNRhtzvbjhJfat6db+tYmIOrzR3b2tSsr7uDliSFjNB3DCKxOatZdMWwi1CCOBHqYVQwHuTtjxzBgUGSqxNTkLh9LyMamPKTRN7xeIkeHemPbRPnFVhKeLCp4upvcZ5OGMl6f2xrMTe9RZ1TChty/euzcKrmolnFVKDA83TTT8ZPYAxKVcw4z+je/S3Kl6WOhiXgn6vbEdTk1YORHkUbdA2vOTemLuF4cxxWInbK2TA/KKTXMOXp8egTsHBCFbX2a1Q/fKuYPEr79+ZAgAYObgYLiqlXVWgXm6qLDzuTFQKxR49sfERisKW4qO3VXvYx/PHlBnN2zA1Hs1t3rZvtbJAfnF1vMnzMHztTv6QOvkgP4hHgh0d8Ks/x40tde1JsD0DtDgz/0D8UvCFZzNLsLkCNPx1JxCzFtV0/ORV1SOc9mN90JY1pOxnCgLALurJ59ufGIkPthxDjtOZyNA64hp/QLEicYf3NfPqrdDIZfhnT/3hUopt5oPMyLcC189PETS/2cMI0QkqRn9AmGoMOLF6hLtXq7Wv53eakEEgNVv0JZj7OYP1u6+bnii1nO8XNXY/+JtDfZa2VpeKZPJ8OdaQwCAqUfp3sG2h4dqs1yFoSutELvwzf7cPxCTI/yw5WQWfjluCoaWvT9mY7p3wo5nxsBXU/N6c6I74/0dZzG4s6mnpqnzkrr71l+Twzw/xMO58VU0Yd4uKDZUIqfQIPbuBHk44eP7ByAluxAv/HQC43p0gkIuQ6hF5VxHBzkqqgT8465IcUhvRr/AOoXozMHT21WNt6qH3M5aDIl41epN6R2gwS8JV/B1/EX8dOwyXp/eB6//elLswQFqekmUchmGh3tjr8Ucl7rXQm01Ibe2PgEafDJ7ANYeSUdUkDv8tY64pjdg9rBQBLg7wV/riOWzB6BLJ1d4uqjEfwtje/jg4IU8jAz3hpODQvJaIwwjRCQpmUyG+4aEoKC0AvHn88RhiFuZTCZDgNYRV3VlmNin6UNGUg2f1bfs+43pfRDu4you6xwS5imGkfr2sgmvVWNm4diu6NLJRZwD05IMFvsz2TIo1APfzh+Kz/el4Z9bTMuMB4S448P7+iPY0xlRwe4Y3NlT/AAO8nDGy7f3gpNKgWFdPFFSXmVV7G3h2K7w1agR3dUL3x/JwNWCUpsrbiwDiLbWMIh5joap4F45Hm5gO4LO3i747MGBuJRXYjX0ZMnDWYXCsso6w2gA8NfJPSGTyaBSyqyKMr43s5/4tUwmw5S+dYeYtE4OmFSrErKUGEaI6JawYEzXZtdmkNKGxSORcb0EAxpYNnyrG9Wtk9XKJ62TA8J9XKEvragTOuqjUspbbedgN8eaj6g1jw7FkYvXrXouunZyhVqpwPxRXTC6Wyf09tfU+Q2/9squ+aPrn2OjUspxX/V2Cy9M7lnveZY1a2oHzD4BGqvltLXfjwyAvrp2zJjuneDooEAPPzcsHheO/+xOBWAKfG/f2RevbEjGE7d1g6ODHB/sOId/3h2JyioBpzP1uKNfgOSFyloSwwgR0Q3o5Ka2Gv641ZmXzs6NDsX2U9noH+qBzrU2/JPJZNj4xEgYBeGW+KB74rZuOHlVj/uHhmB4uDcGhHrgj8sFYpE58/5LDgo5IgK1Db1Ui7JcwVS7B8lFrURPP02d1S4OChk2PTkKBSUVOHnVtJJmssXcm2cndsec4aEwVBjh4aKCq1ppVfnZckfv2ivQOgKZIAitv4D8Jun1emi1Wuh0Omg0msafQEREViqqjMgtMsBfe/M790rNXGNl8bhwPDephyRt+O1EJk5n6vHsxO51ekeW/pKE7w6b6o48O6E7Tmfp8eCwzohuhaGsW11TP7/ZM0JEZAccFPIOEUQA4IuHBuHXxKt4fEzjy5pby9RIf0ytZ7nvpD6+Yhh5Yny3tmxWu8WeESIioha2KSkTIZ7ObTp8dCtizwgREZFEbrexgoXqd+st4CciIiK7wjBCREREkmp2GNm7dy+mTZuGgIAAyGQyrF+/vtHn7NmzBwMGDIBarUZ4eDhWrVp1A00lIiKijqjZYaS4uBhRUVH4+OOPm3R+Wloapk6dinHjxiExMRFLlizBo48+iq1btza7sURERNTxNHsC65QpUzBlypQmn79ixQqEhYXh3XffBQD06tUL+/btw/vvv49JkyY199sTERFRB9Pqc0bi4+MRExNjdWzSpEmIj4+v9zkGgwF6vd7qRkRERB1Tq4eRrKws+PpabyTl6+sLvV6P0tJSm8+JjY2FVqsVb8HBTduZkoiIiNqfW3I1zdKlS6HT6cRbRkaG1E0iIiKiVtLqRc/8/PyQnZ1tdSw7OxsajQZOTrZLE6vVaqjV7WcDKiIiIrpxrd4zEh0djZ07d1od2759O6Kjo1v7WxMREVE70OwwUlRUhMTERCQmJgIwLd1NTExEerppU6ClS5dizpw54vkLFizAhQsX8MILL+DMmTP45JNP8MMPP+Dpp59umXdARERE7Vqzw8jRo0fRv39/9O/fHwDwzDPPoH///njllVcAAJmZmWIwAYCwsDD89ttv2L59O6KiovDuu+9i5cqVXNZLREREALhrLxEREbWSDrVrrzkvsd4IERFR+2H+3G6s36NdhJHCwkIAYL0RIiKidqiwsBBarbbex9vFMI3RaMTVq1fh5uYGmUzWYq+r1+sRHByMjIwMDv9Y4HWxjdfFNl4X23hdbON1sa2jXhdBEFBYWIiAgADI5fVPU20XPSNyuRxBQUGt9voajaZD/eW3FF4X23hdbON1sY3XxTZeF9s64nVpqEfE7JaswEpERET2g2GEiIiIJGXXYUStVuPVV19l6flaeF1s43WxjdfFNl4X23hdbLP369IuJrASERFRx2XXPSNEREQkPYYRIiIikhTDCBEREUmKYYSIiIgkZddh5OOPP0bnzp3h6OiIoUOH4vDhw1I3qVXt3bsX06ZNQ0BAAGQyGdavX2/1uCAIeOWVV+Dv7w8nJyfExMTg3LlzVufk5+dj9uzZ0Gg0cHd3x7x581BUVNSG76JlxcbGYvDgwXBzc4OPjw9mzJiBlJQUq3PKysqwaNEieHl5wdXVFXfddReys7OtzklPT8fUqVPh7OwMHx8fPP/886isrGzLt9Kili9fjsjISLEAU3R0NDZv3iw+bo/XpLZ33nkHMpkMS5YsEY/Z63V57bXXIJPJrG49e/YUH7fX6wIAV65cwQMPPAAvLy84OTmhb9++OHr0qPi4Pf7ctUmwU2vXrhVUKpXwxRdfCCdPnhTmz58vuLu7C9nZ2VI3rdVs2rRJePnll4VffvlFACCsW7fO6vF33nlH0Gq1wvr164U//vhDuOOOO4SwsDChtLRUPGfy5MlCVFSUcPDgQeH3338XwsPDhVmzZrXxO2k5kyZNEr788kshOTlZSExMFG6//XYhJCREKCoqEs9ZsGCBEBwcLOzcuVM4evSoMGzYMGH48OHi45WVlUJERIQQExMjJCQkCJs2bRK8vb2FpUuXSvGWWsSvv/4q/Pbbb8LZs2eFlJQU4aWXXhIcHByE5ORkQRDs85pYOnz4sNC5c2chMjJSeOqpp8Tj9npdXn31VaFPnz5CZmameLt27Zr4uL1el/z8fCE0NFR46KGHhEOHDgkXLlwQtm7dKqSmporn2OPPXVvsNowMGTJEWLRokXi/qqpKCAgIEGJjYyVsVdupHUaMRqPg5+cnLFu2TDxWUFAgqNVq4bvvvhMEQRBOnTolABCOHDkinrN582ZBJpMJV65cabO2t6acnBwBgBAXFycIgukaODg4CD/++KN4zunTpwUAQnx8vCAIppAnl8uFrKws8Zzly5cLGo1GMBgMbfsGWpGHh4ewcuVKu78mhYWFQrdu3YTt27cLY8aMEcOIPV+XV199VYiKirL5mD1fl7/+9a/CyJEj632cP3dr2OUwTXl5OY4dO4aYmBjxmFwuR0xMDOLj4yVsmXTS0tKQlZVldU20Wi2GDh0qXpP4+Hi4u7tj0KBB4jkxMTGQy+U4dOhQm7e5Neh0OgCAp6cnAODYsWOoqKiwui49e/ZESEiI1XXp27cvfH19xXMmTZoEvV6PkydPtmHrW0dVVRXWrl2L4uJiREdH2/01WbRoEaZOnWr1/gH+Wzl37hwCAgLQpUsXzJ49G+np6QDs+7r8+uuvGDRoEO655x74+Pigf//++O9//ys+zp+7NewyjOTm5qKqqsrqHz4A+Pr6IisrS6JWScv8vhu6JllZWfDx8bF6XKlUwtPTs0NcN6PRiCVLlmDEiBGIiIgAYHrPKpUK7u7uVufWvi62rpv5sfYqKSkJrq6uUKvVWLBgAdatW4fevXvb9TVZu3Ytjh8/jtjY2DqP2fN1GTp0KFatWoUtW7Zg+fLlSEtLw6hRo1BYWGjX1+XChQtYvnw5unXrhq1bt2LhwoV48skn8dVXXwHgz11L7WLXXqK2sGjRIiQnJ2Pfvn1SN+WW0KNHDyQmJkKn0+Gnn37C3LlzERcXJ3WzJJORkYGnnnoK27dvh6Ojo9TNuaVMmTJF/DoyMhJDhw5FaGgofvjhBzg5OUnYMmkZjUYMGjQIb7/9NgCgf//+SE5OxooVKzB37lyJW3drscueEW9vbygUijqzubOzs+Hn5ydRq6Rlft8NXRM/Pz/k5ORYPV5ZWYn8/Px2f90WL16MjRs3Yvfu3QgKChKP+/n5oby8HAUFBVbn174utq6b+bH2SqVSITw8HAMHDkRsbCyioqLw4Ycf2u01OXbsGHJycjBgwAAolUoolUrExcXh3//+N5RKJXx9fe3yutji7u6O7t27IzU11W7/vQCAv78/evfubXWsV69e4hCWvf/ctWSXYUSlUmHgwIHYuXOneMxoNGLnzp2Ijo6WsGXSCQsLg5+fn9U10ev1OHTokHhNoqOjUVBQgGPHjonn7Nq1C0ajEUOHDm3zNrcEQRCwePFirFu3Drt27UJYWJjV4wMHDoSDg4PVdUlJSUF6errVdUlKSrL6gbF9+3ZoNJo6P4jaM6PRCIPBYLfXZPz48UhKSkJiYqJ4GzRoEGbPni1+bY/XxZaioiKcP38e/v7+dvvvBQBGjBhRp1TA2bNnERoaCsB+f+7aJPUMWqmsXbtWUKvVwqpVq4RTp04Jjz32mODu7m41m7ujKSwsFBISEoSEhAQBgPDee+8JCQkJwqVLlwRBMC0xc3d3FzZs2CCcOHFCmD59us0lZv379xcOHTok7Nu3T+jWrVu7XmK2cOFCQavVCnv27LFallhSUiKes2DBAiEkJETYtWuXcPToUSE6OlqIjo4WHzcvS5w4caKQmJgobNmyRejUqVO7Xpb44osvCnFxcUJaWppw4sQJ4cUXXxRkMpmwbds2QRDs85rYYrmaRhDs97o8++yzwp49e4S0tDRh//79QkxMjODt7S3k5OQIgmC/1+Xw4cOCUqkU3nrrLeHcuXPCt99+Kzg7OwurV68Wz7HHn7u22G0YEQRB+Oijj4SQkBBBpVIJQ4YMEQ4ePCh1k1rV7t27BQB1bnPnzhUEwbTM7O9//7vg6+srqNVqYfz48UJKSorVa+Tl5QmzZs0SXF1dBY1GIzz88MNCYWGhBO+mZdi6HgCEL7/8UjyntLRU+Mtf/iJ4eHgIzs7Owp133ilkZmZavc7FixeFKVOmCE5OToK3t7fw7LPPChUVFW38blrOI488IoSGhgoqlUro1KmTMH78eDGICIJ9XhNbaocRe70uM2fOFPz9/QWVSiUEBgYKM2fOtKqlYa/XRRAE4X//+58QEREhqNVqoWfPnsJnn31m9bg9/ty1RSYIgiBNnwwRERGRnc4ZISIiolsHwwgRERFJimGEiIiIJMUwQkRERJJiGCEiIiJJMYwQERGRpBhGiIiISFIMI0RERCQphhEiIiKSFMMIERERSYphhIiIiCTFMEJERESS+n+vc4gDPqKMswAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(torch.tensor(l).view(-1, 10).mean(1).numpy())" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "JQ23lyBOsiST", "outputId": "0e3c46a9-e58a-489f-cfb2-46e41d7c688f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def gandas(self).s))).Streamed for i += num))\n", " \"\"\"\n", " batchraces.\n", " ...\n", " \"\"\"\n", " if self._jvm.SSL0 0.1:\n", " name = thod:\n", " \"\"\"\n", " ... (argitparam) for Jrbteast = df.tors.defaultBy recrient short \n" ] } ], "source": [ "# 使用模型来生成文本\n", "context = torch.zeros((1, 10), dtype=torch.long, device=device)\n", "print(''.join(tok.decode(generate(model, context))))" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "V100", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 1 }