{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "48HjPXSxsiSO", "outputId": "1d36f06e-16a4-42bc-eba7-dc7d249589d7" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader\n", "from datasets import load_dataset\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "\n", "\n", "torch.manual_seed(12046)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Hn9ypPW0siSP" }, "outputs": [], "source": [ "# 一些超参数\n", "context_length = 10\n", "learning_rate = 0.01\n", "eval_iters = 10\n", "batch_size=1000\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "P_-nzG89siSQ", "outputId": "4329e089-bae9-417b-e67f-7cef29c9d1d4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def to_arrow_schema(schema):\n", " \"\"\" Convert a schema from Spark to Arrow\n", " \"\"\"\n", " import pyarrow as pa\n", " fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)\n", " for field in schema]\n", " return pa.schema(fields)\n", "['def to_arrow_schema(schema):\\n \"\"\" Convert a schema from Spark to Arrow\\n \"\"\"\\n import pyarrow as pa\\n fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)\\n for field in schema]\\n return pa.schema(fields)', 'def from_arrow_type(at):\\n \"\"\" Convert pyarrow type to Spark data type.\\n \"\"\"\\n import pyarrow.types as types\\n if types.is_boolean(at):\\n spark_type = BooleanType()\\n elif types.is_int8(at):\\n spark_type = ByteType()\\n elif types.is_int16(at):\\n spark_type = ShortType()\\n elif types.is_int32(at):\\n spark_type = IntegerType()\\n elif types.is_int64(at):\\n spark_type = LongType()\\n elif types.is_float32(at):\\n spark_type = FloatType()\\n elif types.is_float64(at):\\n spark_type = DoubleType()\\n elif types.is_decimal(at):\\n spark_type = DecimalType(precision=at.precision, scale=at.scale)\\n elif types.is_string(at):\\n spark_type = StringType()\\n elif types.is_binary(at):\\n spark_type = BinaryType()\\n elif types.is_date32(at):\\n spark_type = DateType()\\n elif types.is_timestamp(at):\\n spark_type = TimestampType()\\n elif types.is_list(at):\\n if types.is_timestamp(at.value_type):\\n raise TypeError(\"Unsupported type in conversion from Arrow: \" + str(at))\\n spark_type = ArrayType(from_arrow_type(at.value_type))\\n elif types.is_struct(at):\\n if any(types.is_struct(field.type) for field in at):\\n raise TypeError(\"Nested StructType not supported in conversion from Arrow: \" + str(at))\\n return StructType(\\n [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)\\n for field in at])\\n else:\\n raise TypeError(\"Unsupported type in conversion from Arrow: \" + str(at))\\n return spark_type']\n" ] } ], "source": [ "raw_datasets = load_dataset('code_search_net', 'python')\n", "datasets = raw_datasets['train'].filter(lambda x: 'apache/spark' in x['repository_name'])\n", "# 通过索引提取datasets数据的时候,返回一个dict,其中的value是一个字符串\n", "print(datasets[8]['whole_func_string'])\n", "# 当传入的是一个数组时,返回的依然是一个dict,但其中的value是一个列表\n", "print(datasets[8: 10]['whole_func_string'])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "Ky6gh_4TsiSQ" }, "outputs": [], "source": [ "class char_tokenizer:\n", "\n", " def __init__(self, data, begin_ind=0, end_ind=1):\n", " # 数据中出现的所有字符构成字典\n", " chars = sorted(list(set(''.join(data))))\n", " # 预留两个位置给开头和结尾的特殊字符\n", " self.char2ind = {s : i + 2 for i, s in enumerate(chars)}\n", " self.char2ind['<|b|>'] = begin_ind\n", " self.char2ind['<|e|>'] = end_ind\n", " self.begin_ind = begin_ind\n", " self.end_ind = end_ind\n", " self.ind2char = {i : s for s, i in self.char2ind.items()}\n", "\n", " def encode(self, text):\n", " '''\n", " 编码\n", " 参数\n", " ----\n", " text :str,文本\n", " '''\n", " return [self.char2ind[c] for c in text]\n", "\n", " def decode(self, enc):\n", " '''\n", " 解码\n", " 参数\n", " ----\n", " enc :int or list[int]\n", " '''\n", " if isinstance(enc, int):\n", " return self.ind2char[enc]\n", " return [self.ind2char[i] for i in enc]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yDS5elR7siSR", "outputId": "5c2e18d7-3f81-4873-a14d-c8a7b900a0d1" }, "outputs": [ { "data": { "text/plain": [ "('def postappend(self):', 99)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 举例验证分词器\n", "tok = char_tokenizer(datasets['whole_func_string'])\n", "example_text = 'def postappend(self):'\n", "''.join(tok.decode(tok.encode(example_text))), len(tok.char2ind)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "egLVxc1VsiSR", "outputId": "5d0b0760-2267-47d8-93f1-9bb326f25719" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|> ---> d\n", "<|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|>d ---> e\n", "<|b|><|b|><|b|><|b|><|b|><|b|><|b|><|b|>de ---> f\n", "<|b|><|b|><|b|><|b|><|b|><|b|><|b|>def ---> \n", "<|b|><|b|><|b|><|b|><|b|><|b|>def ---> p\n", "<|b|><|b|><|b|><|b|><|b|>def p ---> o\n", "<|b|><|b|><|b|><|b|>def po ---> s\n", "<|b|><|b|><|b|>def pos ---> t\n", "<|b|><|b|>def post ---> a\n", "<|b|>def posta ---> p\n", "def postap ---> p\n", "ef postapp ---> e\n", "f postappe ---> n\n", " postappen ---> d\n", "postappend ---> (\n", "ostappend( ---> s\n", "stappend(s ---> e\n", "tappend(se ---> l\n", "append(sel ---> f\n", "ppend(self ---> )\n", "pend(self) ---> :\n", "end(self): ---> <|e|>\n" ] } ], "source": [ "def autoregressive_trans(text, tokenizer, context_length=context_length):\n", " '''\n", " 将文本转换成一系列的训练数据\n", " 参数\n", " ----\n", " text :str,文本\n", " tokenizer :分词器\n", " context_length :int,背景文本的长度\n", " 返回\n", " ----\n", " inputs :list[list[int]],背景文本(特征)\n", " labels :list[list[int]],预测标签\n", " '''\n", " inputs, labels = [], []\n", " b_ind = tokenizer.begin_ind\n", " e_ind = tokenizer.end_ind\n", " enc = tokenizer.encode(text)\n", " # 增加开始和结尾的特殊字符\n", " x = [b_ind] * context_length + enc + [e_ind]\n", " for i in range(len(x) - context_length):\n", " inputs.append(x[i: i + context_length])\n", " labels.append(x[i + context_length])\n", " return inputs, labels\n", "\n", "# 举例展示自回归模式的训练数据\n", "inputs, labels = autoregressive_trans(example_text, tok)\n", "for a, b in zip(inputs, labels):\n", " print(''.join(tok.decode(a)), '--->', tok.decode(b))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7kOWxKCisiSR", "outputId": "82ab4276-dc85-42e2-c6a9-e2e99629928f" }, "outputs": [ { "data": { "text/plain": [ "{'inputs': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 0, 71],\n", " [0, 0, 0, 0, 0, 0, 0, 0, 71, 72],\n", " [0, 0, 0, 0, 0, 0, 0, 71, 72, 73],\n", " [0, 0, 0, 0, 0, 0, 71, 72, 73, 3],\n", " [0, 0, 0, 0, 0, 71, 72, 73, 3, 87],\n", " [0, 0, 0, 0, 71, 72, 73, 3, 87, 82],\n", " [0, 0, 0, 71, 72, 73, 3, 87, 82, 66],\n", " [0, 0, 71, 72, 73, 3, 87, 82, 66, 68],\n", " [0, 71, 72, 73, 3, 87, 82, 66, 68, 85],\n", " [71, 72, 73, 3, 87, 82, 66, 68, 85, 85],\n", " [72, 73, 3, 87, 82, 66, 68, 85, 85, 82],\n", " [73, 3, 87, 82, 66, 68, 85, 85, 82, 90],\n", " [3, 87, 82, 66, 68, 85, 85, 82, 90, 66],\n", " [87, 82, 66, 68, 85, 85, 82, 90, 66, 86],\n", " [82, 66, 68, 85, 85, 82, 90, 66, 86, 70],\n", " [66, 68, 85, 85, 82, 90, 66, 86, 70, 75],\n", " [68, 85, 85, 82, 90, 66, 86, 70, 75, 72],\n", " [85, 85, 82, 90, 66, 86, 70, 75, 72, 80],\n", " [85, 82, 90, 66, 86, 70, 75, 72, 80, 68],\n", " [82, 90, 66, 86, 70, 75, 72, 80, 68, 11],\n", " [90, 66, 86, 70, 75, 72, 80, 68, 11, 86],\n", " [66, 86, 70, 75, 72, 80, 68, 11, 86, 70],\n", " [86, 70, 75, 72, 80, 68, 11, 86, 70, 75],\n", " [70, 75, 72, 80, 68, 11, 86, 70, 75, 72],\n", " [75, 72, 80, 68, 11, 86, 70, 75, 72, 80],\n", " [72, 80, 68, 11, 86, 70, 75, 72, 80, 68],\n", " [80, 68, 11, 86, 70, 75, 72, 80, 68, 12],\n", " [68, 11, 86, 70, 75, 72, 80, 68, 12, 29],\n", " [11, 86, 70, 75, 72, 80, 68, 12, 29, 2],\n", " [86, 70, 75, 72, 80, 68, 12, 29, 2, 3],\n", " [70, 75, 72, 80, 68, 12, 29, 2, 3, 3],\n", " [75, 72, 80, 68, 12, 29, 2, 3, 3, 3],\n", " [72, 80, 68, 12, 29, 2, 3, 3, 3, 3],\n", " [80, 68, 12, 29, 2, 3, 3, 3, 3, 5],\n", " [68, 12, 29, 2, 3, 3, 3, 3, 5, 5],\n", " [12, 29, 2, 3, 3, 3, 3, 5, 5, 5],\n", " [29, 2, 3, 3, 3, 3, 5, 5, 5, 3],\n", " [2, 3, 3, 3, 3, 5, 5, 5, 3, 38],\n", " [3, 3, 3, 3, 5, 5, 5, 3, 38, 82],\n", " [3, 3, 3, 5, 5, 5, 3, 38, 82, 81],\n", " [3, 3, 5, 5, 5, 3, 38, 82, 81, 89],\n", " [3, 5, 5, 5, 3, 38, 82, 81, 89, 72],\n", " [5, 5, 5, 3, 38, 82, 81, 89, 72, 85],\n", " [5, 5, 3, 38, 82, 81, 89, 72, 85, 87],\n", " [5, 3, 38, 82, 81, 89, 72, 85, 87, 3],\n", " [3, 38, 82, 81, 89, 72, 85, 87, 3, 68],\n", " [38, 82, 81, 89, 72, 85, 87, 3, 68, 3],\n", " [82, 81, 89, 72, 85, 87, 3, 68, 3, 86],\n", " [81, 89, 72, 85, 87, 3, 68, 3, 86, 70],\n", " [89, 72, 85, 87, 3, 68, 3, 86, 70, 75],\n", " [72, 85, 87, 3, 68, 3, 86, 70, 75, 72],\n", " [85, 87, 3, 68, 3, 86, 70, 75, 72, 80],\n", " [87, 3, 68, 3, 86, 70, 75, 72, 80, 68],\n", " [3, 68, 3, 86, 70, 75, 72, 80, 68, 3],\n", " [68, 3, 86, 70, 75, 72, 80, 68, 3, 73],\n", " [3, 86, 70, 75, 72, 80, 68, 3, 73, 85],\n", " [86, 70, 75, 72, 80, 68, 3, 73, 85, 82],\n", " [70, 75, 72, 80, 68, 3, 73, 85, 82, 80],\n", " [75, 72, 80, 68, 3, 73, 85, 82, 80, 3],\n", " [72, 80, 68, 3, 73, 85, 82, 80, 3, 54],\n", " [80, 68, 3, 73, 85, 82, 80, 3, 54, 83],\n", " [68, 3, 73, 85, 82, 80, 3, 54, 83, 68],\n", " [3, 73, 85, 82, 80, 3, 54, 83, 68, 85],\n", " [73, 85, 82, 80, 3, 54, 83, 68, 85, 78],\n", " [85, 82, 80, 3, 54, 83, 68, 85, 78, 3],\n", " [82, 80, 3, 54, 83, 68, 85, 78, 3, 87],\n", " [80, 3, 54, 83, 68, 85, 78, 3, 87, 82],\n", " [3, 54, 83, 68, 85, 78, 3, 87, 82, 3],\n", " [54, 83, 68, 85, 78, 3, 87, 82, 3, 36],\n", " [83, 68, 85, 78, 3, 87, 82, 3, 36, 85],\n", " [68, 85, 78, 3, 87, 82, 3, 36, 85, 85],\n", " [85, 78, 3, 87, 82, 3, 36, 85, 85, 82],\n", " [78, 3, 87, 82, 3, 36, 85, 85, 82, 90],\n", " [3, 87, 82, 3, 36, 85, 85, 82, 90, 2],\n", " [87, 82, 3, 36, 85, 85, 82, 90, 2, 3],\n", " [82, 3, 36, 85, 85, 82, 90, 2, 3, 3],\n", " [3, 36, 85, 85, 82, 90, 2, 3, 3, 3],\n", " [36, 85, 85, 82, 90, 2, 3, 3, 3, 3],\n", " [85, 85, 82, 90, 2, 3, 3, 3, 3, 5],\n", " [85, 82, 90, 2, 3, 3, 3, 3, 5, 5],\n", " [82, 90, 2, 3, 3, 3, 3, 5, 5, 5],\n", " [90, 2, 3, 3, 3, 3, 5, 5, 5, 2],\n", " [2, 3, 3, 3, 3, 5, 5, 5, 2, 3],\n", " [3, 3, 3, 3, 5, 5, 5, 2, 3, 3],\n", " [3, 3, 3, 5, 5, 5, 2, 3, 3, 3],\n", " [3, 3, 5, 5, 5, 2, 3, 3, 3, 3],\n", " [3, 5, 5, 5, 2, 3, 3, 3, 3, 76],\n", " [5, 5, 5, 2, 3, 3, 3, 3, 76, 80],\n", " [5, 5, 2, 3, 3, 3, 3, 76, 80, 83],\n", " [5, 2, 3, 3, 3, 3, 76, 80, 83, 82],\n", " [2, 3, 3, 3, 3, 76, 80, 83, 82, 85],\n", " [3, 3, 3, 3, 76, 80, 83, 82, 85, 87],\n", " [3, 3, 3, 76, 80, 83, 82, 85, 87, 3],\n", " [3, 3, 76, 80, 83, 82, 85, 87, 3, 83],\n", " [3, 76, 80, 83, 82, 85, 87, 3, 83, 92],\n", " [76, 80, 83, 82, 85, 87, 3, 83, 92, 68],\n", " [80, 83, 82, 85, 87, 3, 83, 92, 68, 85],\n", " [83, 82, 85, 87, 3, 83, 92, 68, 85, 85],\n", " [82, 85, 87, 3, 83, 92, 68, 85, 85, 82],\n", " [85, 87, 3, 83, 92, 68, 85, 85, 82, 90],\n", " [87, 3, 83, 92, 68, 85, 85, 82, 90, 3],\n", " [3, 83, 92, 68, 85, 85, 82, 90, 3, 68],\n", " [83, 92, 68, 85, 85, 82, 90, 3, 68, 86],\n", " [92, 68, 85, 85, 82, 90, 3, 68, 86, 3],\n", " [68, 85, 85, 82, 90, 3, 68, 86, 3, 83],\n", " [85, 85, 82, 90, 3, 68, 86, 3, 83, 68],\n", " [85, 82, 90, 3, 68, 86, 3, 83, 68, 2],\n", " [82, 90, 3, 68, 86, 3, 83, 68, 2, 3],\n", " [90, 3, 68, 86, 3, 83, 68, 2, 3, 3],\n", " [3, 68, 86, 3, 83, 68, 2, 3, 3, 3],\n", " [68, 86, 3, 83, 68, 2, 3, 3, 3, 3],\n", " [86, 3, 83, 68, 2, 3, 3, 3, 3, 73],\n", " [3, 83, 68, 2, 3, 3, 3, 3, 73, 76],\n", " [83, 68, 2, 3, 3, 3, 3, 73, 76, 72],\n", " [68, 2, 3, 3, 3, 3, 73, 76, 72, 79],\n", " [2, 3, 3, 3, 3, 73, 76, 72, 79, 71],\n", " [3, 3, 3, 3, 73, 76, 72, 79, 71, 86],\n", " [3, 3, 3, 73, 76, 72, 79, 71, 86, 3],\n", " [3, 3, 73, 76, 72, 79, 71, 86, 3, 32],\n", " [3, 73, 76, 72, 79, 71, 86, 3, 32, 3],\n", " [73, 76, 72, 79, 71, 86, 3, 32, 3, 62],\n", " [76, 72, 79, 71, 86, 3, 32, 3, 62, 83],\n", " [72, 79, 71, 86, 3, 32, 3, 62, 83, 68],\n", " [79, 71, 86, 3, 32, 3, 62, 83, 68, 17],\n", " [71, 86, 3, 32, 3, 62, 83, 68, 17, 73],\n", " [86, 3, 32, 3, 62, 83, 68, 17, 73, 76],\n", " [3, 32, 3, 62, 83, 68, 17, 73, 76, 72],\n", " [32, 3, 62, 83, 68, 17, 73, 76, 72, 79],\n", " [3, 62, 83, 68, 17, 73, 76, 72, 79, 71],\n", " [62, 83, 68, 17, 73, 76, 72, 79, 71, 11],\n", " [83, 68, 17, 73, 76, 72, 79, 71, 11, 73],\n", " [68, 17, 73, 76, 72, 79, 71, 11, 73, 76],\n", " [17, 73, 76, 72, 79, 71, 11, 73, 76, 72],\n", " [73, 76, 72, 79, 71, 11, 73, 76, 72, 79],\n", " [76, 72, 79, 71, 11, 73, 76, 72, 79, 71],\n", " [72, 79, 71, 11, 73, 76, 72, 79, 71, 17],\n", " [79, 71, 11, 73, 76, 72, 79, 71, 17, 81],\n", " [71, 11, 73, 76, 72, 79, 71, 17, 81, 68],\n", " [11, 73, 76, 72, 79, 71, 17, 81, 68, 80],\n", " [73, 76, 72, 79, 71, 17, 81, 68, 80, 72],\n", " [76, 72, 79, 71, 17, 81, 68, 80, 72, 15],\n", " [72, 79, 71, 17, 81, 68, 80, 72, 15, 3],\n", " [79, 71, 17, 81, 68, 80, 72, 15, 3, 87],\n", " [71, 17, 81, 68, 80, 72, 15, 3, 87, 82],\n", " [17, 81, 68, 80, 72, 15, 3, 87, 82, 66],\n", " [81, 68, 80, 72, 15, 3, 87, 82, 66, 68],\n", " [68, 80, 72, 15, 3, 87, 82, 66, 68, 85],\n", " [80, 72, 15, 3, 87, 82, 66, 68, 85, 85],\n", " [72, 15, 3, 87, 82, 66, 68, 85, 85, 82],\n", " [15, 3, 87, 82, 66, 68, 85, 85, 82, 90],\n", " [3, 87, 82, 66, 68, 85, 85, 82, 90, 66],\n", " [87, 82, 66, 68, 85, 85, 82, 90, 66, 87],\n", " [82, 66, 68, 85, 85, 82, 90, 66, 87, 92],\n", " [66, 68, 85, 85, 82, 90, 66, 87, 92, 83],\n", " [68, 85, 85, 82, 90, 66, 87, 92, 83, 72],\n", " [85, 85, 82, 90, 66, 87, 92, 83, 72, 11],\n", " [85, 82, 90, 66, 87, 92, 83, 72, 11, 73],\n", " [82, 90, 66, 87, 92, 83, 72, 11, 73, 76],\n", " [90, 66, 87, 92, 83, 72, 11, 73, 76, 72],\n", " [66, 87, 92, 83, 72, 11, 73, 76, 72, 79],\n", " [87, 92, 83, 72, 11, 73, 76, 72, 79, 71],\n", " [92, 83, 72, 11, 73, 76, 72, 79, 71, 17],\n", " [83, 72, 11, 73, 76, 72, 79, 71, 17, 71],\n", " [72, 11, 73, 76, 72, 79, 71, 17, 71, 68],\n", " [11, 73, 76, 72, 79, 71, 17, 71, 68, 87],\n", " [73, 76, 72, 79, 71, 17, 71, 68, 87, 68],\n", " [76, 72, 79, 71, 17, 71, 68, 87, 68, 55],\n", " [72, 79, 71, 17, 71, 68, 87, 68, 55, 92],\n", " [79, 71, 17, 71, 68, 87, 68, 55, 92, 83],\n", " [71, 17, 71, 68, 87, 68, 55, 92, 83, 72],\n", " [17, 71, 68, 87, 68, 55, 92, 83, 72, 12],\n", " [71, 68, 87, 68, 55, 92, 83, 72, 12, 15],\n", " [68, 87, 68, 55, 92, 83, 72, 12, 15, 3],\n", " [87, 68, 55, 92, 83, 72, 12, 15, 3, 81],\n", " [68, 55, 92, 83, 72, 12, 15, 3, 81, 88],\n", " [55, 92, 83, 72, 12, 15, 3, 81, 88, 79],\n", " [92, 83, 72, 12, 15, 3, 81, 88, 79, 79],\n", " [83, 72, 12, 15, 3, 81, 88, 79, 79, 68],\n", " [72, 12, 15, 3, 81, 88, 79, 79, 68, 69],\n", " [12, 15, 3, 81, 88, 79, 79, 68, 69, 79],\n", " [15, 3, 81, 88, 79, 79, 68, 69, 79, 72],\n", " [3, 81, 88, 79, 79, 68, 69, 79, 72, 32],\n", " [81, 88, 79, 79, 68, 69, 79, 72, 32, 73],\n", " [88, 79, 79, 68, 69, 79, 72, 32, 73, 76],\n", " [79, 79, 68, 69, 79, 72, 32, 73, 76, 72],\n", " [79, 68, 69, 79, 72, 32, 73, 76, 72, 79],\n", " [68, 69, 79, 72, 32, 73, 76, 72, 79, 71],\n", " [69, 79, 72, 32, 73, 76, 72, 79, 71, 17],\n", " [79, 72, 32, 73, 76, 72, 79, 71, 17, 81],\n", " [72, 32, 73, 76, 72, 79, 71, 17, 81, 88],\n", " [32, 73, 76, 72, 79, 71, 17, 81, 88, 79],\n", " [73, 76, 72, 79, 71, 17, 81, 88, 79, 79],\n", " [76, 72, 79, 71, 17, 81, 88, 79, 79, 68],\n", " [72, 79, 71, 17, 81, 88, 79, 79, 68, 69],\n", " [79, 71, 17, 81, 88, 79, 79, 68, 69, 79],\n", " [71, 17, 81, 88, 79, 79, 68, 69, 79, 72],\n", " [17, 81, 88, 79, 79, 68, 69, 79, 72, 12],\n", " [81, 88, 79, 79, 68, 69, 79, 72, 12, 2],\n", " [88, 79, 79, 68, 69, 79, 72, 12, 2, 3],\n", " [79, 79, 68, 69, 79, 72, 12, 2, 3, 3],\n", " [79, 68, 69, 79, 72, 12, 2, 3, 3, 3],\n", " [68, 69, 79, 72, 12, 2, 3, 3, 3, 3],\n", " [69, 79, 72, 12, 2, 3, 3, 3, 3, 3],\n", " [79, 72, 12, 2, 3, 3, 3, 3, 3, 3],\n", " [72, 12, 2, 3, 3, 3, 3, 3, 3, 3],\n", " [12, 2, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [2, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 3, 73],\n", " [3, 3, 3, 3, 3, 3, 3, 3, 73, 82],\n", " [3, 3, 3, 3, 3, 3, 3, 73, 82, 85],\n", " [3, 3, 3, 3, 3, 3, 73, 82, 85, 3],\n", " [3, 3, 3, 3, 3, 73, 82, 85, 3, 73],\n", " [3, 3, 3, 3, 73, 82, 85, 3, 73, 76],\n", " [3, 3, 3, 73, 82, 85, 3, 73, 76, 72],\n", " [3, 3, 73, 82, 85, 3, 73, 76, 72, 79],\n", " [3, 73, 82, 85, 3, 73, 76, 72, 79, 71],\n", " [73, 82, 85, 3, 73, 76, 72, 79, 71, 3],\n", " [82, 85, 3, 73, 76, 72, 79, 71, 3, 76],\n", " [85, 3, 73, 76, 72, 79, 71, 3, 76, 81],\n", " [3, 73, 76, 72, 79, 71, 3, 76, 81, 3],\n", " [73, 76, 72, 79, 71, 3, 76, 81, 3, 86],\n", " [76, 72, 79, 71, 3, 76, 81, 3, 86, 70],\n", " [72, 79, 71, 3, 76, 81, 3, 86, 70, 75],\n", " [79, 71, 3, 76, 81, 3, 86, 70, 75, 72],\n", " [71, 3, 76, 81, 3, 86, 70, 75, 72, 80],\n", " [3, 76, 81, 3, 86, 70, 75, 72, 80, 68],\n", " [76, 81, 3, 86, 70, 75, 72, 80, 68, 64],\n", " [81, 3, 86, 70, 75, 72, 80, 68, 64, 2],\n", " [3, 86, 70, 75, 72, 80, 68, 64, 2, 3],\n", " [86, 70, 75, 72, 80, 68, 64, 2, 3, 3],\n", " [70, 75, 72, 80, 68, 64, 2, 3, 3, 3],\n", " [75, 72, 80, 68, 64, 2, 3, 3, 3, 3],\n", " [72, 80, 68, 64, 2, 3, 3, 3, 3, 85],\n", " [80, 68, 64, 2, 3, 3, 3, 3, 85, 72],\n", " [68, 64, 2, 3, 3, 3, 3, 85, 72, 87],\n", " [64, 2, 3, 3, 3, 3, 85, 72, 87, 88],\n", " [2, 3, 3, 3, 3, 85, 72, 87, 88, 85],\n", " [3, 3, 3, 3, 85, 72, 87, 88, 85, 81],\n", " [3, 3, 3, 85, 72, 87, 88, 85, 81, 3],\n", " [3, 3, 85, 72, 87, 88, 85, 81, 3, 83],\n", " [3, 85, 72, 87, 88, 85, 81, 3, 83, 68],\n", " [85, 72, 87, 88, 85, 81, 3, 83, 68, 17],\n", " [72, 87, 88, 85, 81, 3, 83, 68, 17, 86],\n", " [87, 88, 85, 81, 3, 83, 68, 17, 86, 70],\n", " [88, 85, 81, 3, 83, 68, 17, 86, 70, 75],\n", " [85, 81, 3, 83, 68, 17, 86, 70, 75, 72],\n", " [81, 3, 83, 68, 17, 86, 70, 75, 72, 80],\n", " [3, 83, 68, 17, 86, 70, 75, 72, 80, 68],\n", " [83, 68, 17, 86, 70, 75, 72, 80, 68, 11],\n", " [68, 17, 86, 70, 75, 72, 80, 68, 11, 73],\n", " [17, 86, 70, 75, 72, 80, 68, 11, 73, 76],\n", " [86, 70, 75, 72, 80, 68, 11, 73, 76, 72],\n", " [70, 75, 72, 80, 68, 11, 73, 76, 72, 79],\n", " [75, 72, 80, 68, 11, 73, 76, 72, 79, 71],\n", " [72, 80, 68, 11, 73, 76, 72, 79, 71, 86],\n", " [80, 68, 11, 73, 76, 72, 79, 71, 86, 12]],\n", " 'labels': [71,\n", " 72,\n", " 73,\n", " 3,\n", " 87,\n", " 82,\n", " 66,\n", " 68,\n", " 85,\n", " 85,\n", " 82,\n", " 90,\n", " 66,\n", " 86,\n", " 70,\n", " 75,\n", " 72,\n", " 80,\n", " 68,\n", " 11,\n", " 86,\n", " 70,\n", " 75,\n", " 72,\n", " 80,\n", " 68,\n", " 12,\n", " 29,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 5,\n", " 5,\n", " 5,\n", " 3,\n", " 38,\n", " 82,\n", " 81,\n", " 89,\n", " 72,\n", " 85,\n", " 87,\n", " 3,\n", " 68,\n", " 3,\n", " 86,\n", " 70,\n", " 75,\n", " 72,\n", " 80,\n", " 68,\n", " 3,\n", " 73,\n", " 85,\n", " 82,\n", " 80,\n", " 3,\n", " 54,\n", " 83,\n", " 68,\n", " 85,\n", " 78,\n", " 3,\n", " 87,\n", " 82,\n", " 3,\n", " 36,\n", " 85,\n", " 85,\n", " 82,\n", " 90,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 5,\n", " 5,\n", " 5,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 76,\n", " 80,\n", " 83,\n", " 82,\n", " 85,\n", " 87,\n", " 3,\n", " 83,\n", " 92,\n", " 68,\n", " 85,\n", " 85,\n", " 82,\n", " 90,\n", " 3,\n", " 68,\n", " 86,\n", " 3,\n", " 83,\n", " 68,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 86,\n", " 3,\n", " 32,\n", " 3,\n", " 62,\n", " 83,\n", " 68,\n", " 17,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 11,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 17,\n", " 81,\n", " 68,\n", " 80,\n", " 72,\n", " 15,\n", " 3,\n", " 87,\n", " 82,\n", " 66,\n", " 68,\n", " 85,\n", " 85,\n", " 82,\n", " 90,\n", " 66,\n", " 87,\n", " 92,\n", " 83,\n", " 72,\n", " 11,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 17,\n", " 71,\n", " 68,\n", " 87,\n", " 68,\n", " 55,\n", " 92,\n", " 83,\n", " 72,\n", " 12,\n", " 15,\n", " 3,\n", " 81,\n", " 88,\n", " 79,\n", " 79,\n", " 68,\n", " 69,\n", " 79,\n", " 72,\n", " 32,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 17,\n", " 81,\n", " 88,\n", " 79,\n", " 79,\n", " 68,\n", " 69,\n", " 79,\n", " 72,\n", " 12,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 73,\n", " 82,\n", " 85,\n", " 3,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 3,\n", " 76,\n", " 81,\n", " 3,\n", " 86,\n", " 70,\n", " 75,\n", " 72,\n", " 80,\n", " 68,\n", " 64,\n", " 2,\n", " 3,\n", " 3,\n", " 3,\n", " 3,\n", " 85,\n", " 72,\n", " 87,\n", " 88,\n", " 85,\n", " 81,\n", " 3,\n", " 83,\n", " 68,\n", " 17,\n", " 86,\n", " 70,\n", " 75,\n", " 72,\n", " 80,\n", " 68,\n", " 11,\n", " 73,\n", " 76,\n", " 72,\n", " 79,\n", " 71,\n", " 86,\n", " 12,\n", " 1]}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def process(data):\n", " '''\n", " 在datasets的map里使用,将文本转换成训练数据\n", " '''\n", " text = data['whole_func_string']\n", " # 如果是普通的map操作,传入的值是字符串\n", " if isinstance(text, str):\n", " inputs, labels = autoregressive_trans(text, tok)\n", " return {'inputs': inputs, 'labels': labels}\n", " # 如果是map操作里面batched=True,传入的值是字符串列表\n", " inputs, labels = [], []\n", " for i in text:\n", " i, l = autoregressive_trans(i, tok)\n", " inputs += i\n", " labels += l\n", " return {'inputs': inputs, 'labels': labels}\n", "\n", "process(datasets[8:9])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "U0Ojs8kbsiSS", "outputId": "0e571fc8-2432-4655-d276-538a79031000" }, "outputs": [ { "data": { "text/plain": [ "(torch.Size([645401, 10]), torch.Size([645401]))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 将数据分为训练集和测试集\n", "tokenized = datasets.train_test_split(test_size=0.1, seed=1024, shuffle=True)\n", "# 将文本转换为训练数据,里面包含inputs和labels\n", "tokenized = tokenized.map(process, batched=True, remove_columns=datasets.column_names)\n", "tokenized.set_format(type='torch', device=device)\n", "\n", "tokenized['train']['inputs'].shape, tokenized['train']['labels'].shape" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SdfNybxcsiSS", "outputId": "d8276b00-2799-4a1c-d0b5-41c9e515803f" }, "outputs": [ { "data": { "text/plain": [ "{'inputs': tensor([[38, 85, 72, ..., 12, 2, 3],\n", " [ 3, 76, 73, ..., 86, 87, 66],\n", " [80, 72, 15, ..., 75, 15, 3],\n", " ...,\n", " [75, 76, 86, ..., 79, 3, 68],\n", " [87, 68, 81, ..., 82, 90, 15],\n", " [ 3, 3, 3, ..., 68, 85, 68]], device='cuda:0'),\n", " 'labels': tensor([ 3, 86, 86, 90, 72, 87, 3, 81, 76, 29, 74, 87, 87, 87, 71, 3, 3, 3,\n", " 15, 3, 68, 3, 87, 75, 86, 17, 75, 13, 3, 80, 3, 83, 86, 68, 3, 83,\n", " 3, 81, 53, 83, 11, 66, 3, 85, 85, 82, 76, 85, 73, 14, 3, 75, 3, 3,\n", " 88, 3, 3, 3, 79, 82, 3, 50, 58, 11, 17, 3, 86, 93, 3, 87, 3, 74,\n", " 78, 78, 80, 54, 3, 2, 81, 80, 3, 81, 3, 3, 74, 3, 76, 82, 82, 3,\n", " 60, 12, 5, 73, 75, 72, 87, 3, 66, 73, 69, 85, 3, 76, 3, 3, 3, 86,\n", " 87, 72, 2, 3, 70, 85, 3, 3, 72, 32, 3, 3, 68, 3, 79, 72, 88, 3,\n", " 70, 76, 11, 3, 80, 72, 3, 87, 71, 3, 76, 73, 11, 82, 8, 84, 3, 72,\n", " 3, 76, 3, 49, 79, 87, 3, 3, 73, 12, 79, 90, 82, 83, 3, 71, 86, 91,\n", " 3, 75, 87, 83, 75, 72, 87, 76, 3, 3, 71, 92, 17, 87, 10, 87, 6, 70,\n", " 10, 3, 88, 3, 81, 86, 12, 72, 3, 3, 81, 72, 19, 68, 76, 44, 73, 82,\n", " 70, 82, 68, 3, 88, 86, 3, 5, 67, 72, 79, 70, 81, 3, 3, 75, 87, 3,\n", " 3, 76, 6, 3, 72, 70, 87, 3, 3, 87, 71, 3, 68, 2, 83, 85, 71, 87,\n", " 72, 33, 15, 87, 74, 3, 3, 29, 68, 3, 75, 85, 80, 3, 70, 72, 70, 82,\n", " 73, 82, 70, 86, 15, 3, 68, 5, 86, 82, 3, 3, 3, 3, 87, 3, 17, 76,\n", " 75, 87, 93, 3, 2, 81, 90, 3, 91, 72, 19, 3, 3, 69, 49, 3, 71, 81,\n", " 70, 5, 33, 3, 87, 3, 76, 20, 29, 3, 71, 68, 88, 86, 86, 17, 77, 82,\n", " 88, 3, 73, 3, 76, 3, 3, 81, 79, 72, 73, 82, 3, 82, 83, 72, 54, 20,\n", " 72, 72, 81, 2, 3, 80, 87, 85, 83, 80, 87, 72, 72, 76, 82, 3, 27, 81,\n", " 3, 3, 3, 86, 80, 66, 17, 76, 71, 80, 87, 17, 79, 83, 76, 83, 3, 3,\n", " 72, 76, 3, 87, 87, 46, 11, 68, 72, 3, 3, 70, 3, 81, 3, 2, 3, 68,\n", " 3, 72, 3, 68, 3, 72, 73, 3, 82, 83, 3, 32, 68, 2, 68, 73, 76, 11,\n", " 68, 68, 3, 3, 2, 3, 72, 17, 3, 55, 51, 3, 3, 87, 87, 3, 72, 87,\n", " 3, 3, 15, 89, 82, 3, 73, 81, 3, 85, 3, 74, 17, 75, 85, 3, 80, 79,\n", " 3, 85, 3, 3, 15, 88, 3, 12, 77, 83, 87, 86, 74, 72, 91, 3, 3, 79,\n", " 71, 29, 72, 3, 72, 3, 68, 81, 73, 3, 71, 74, 86, 3, 87, 3, 79, 17,\n", " 85, 76, 3, 12, 85, 72, 72, 5, 3, 3, 90, 3, 17, 3, 3, 72, 86, 80,\n", " 3, 82, 86, 68, 72, 3, 69, 71, 86, 85, 85, 3, 77, 10, 71, 76, 3, 3,\n", " 3, 85, 87, 5, 76, 81, 3, 3, 90, 68, 71, 81, 85, 80, 3, 79, 68, 17,\n", " 79, 3, 86, 3, 70, 73, 1, 3, 3, 3, 3, 87, 86, 3, 86, 68, 3, 87,\n", " 72, 3, 2, 86, 72, 87, 79, 71, 81, 15, 86, 85, 87, 78, 86, 74, 3, 66,\n", " 79, 82, 76, 80, 82, 2, 79, 70, 74, 51, 68, 66, 3, 3, 79, 3, 71, 2,\n", " 86, 85, 3, 68, 3, 5, 39, 71, 72, 81, 87, 3, 3, 3, 3, 3, 3, 72,\n", " 3, 72, 88, 81, 3, 85, 82, 68, 39, 3, 72, 68, 2, 3, 68, 70, 80, 69,\n", " 72, 2, 49, 82, 82, 81, 90, 3, 76, 2, 66, 72, 3, 3, 3, 3, 88, 72,\n", " 3, 19, 3, 72, 81, 3, 76, 78, 3, 3, 76, 76, 89, 81, 80, 87, 3, 81,\n", " 68, 81, 79, 2, 11, 71, 3, 81, 72, 49, 87, 72, 76, 5, 73, 68, 3, 3,\n", " 79, 72, 71, 21, 3, 85, 17, 3, 3, 3, 79, 76, 3, 3, 68, 82, 3, 87,\n", " 81, 3, 87, 72, 3, 87, 21, 44, 3, 87, 68, 81, 3, 68, 82, 87, 3, 70,\n", " 3, 12, 72, 87, 85, 2, 3, 3, 75, 82, 50, 3, 79, 85, 68, 64, 3, 62,\n", " 8, 11, 79, 72, 3, 76, 85, 15, 33, 32, 68, 3, 5, 3, 76, 21, 88, 79,\n", " 49, 68, 3, 88, 51, 3, 3, 80, 68, 79, 3, 3, 2, 32, 3, 12, 2, 68,\n", " 3, 81, 17, 87, 71, 3, 3, 76, 76, 92, 3, 88, 3, 3, 49, 80, 3, 85,\n", " 3, 3, 19, 79, 49, 87, 3, 3, 76, 80, 3, 68, 2, 71, 11, 3, 13, 3,\n", " 72, 29, 76, 86, 82, 3, 29, 32, 3, 10, 81, 3, 3, 3, 87, 85, 17, 85,\n", " 10, 11, 85, 78, 3, 17, 78, 70, 21, 3, 83, 3, 3, 90, 3, 85, 3, 82,\n", " 71, 88, 3, 63, 80, 89, 54, 69, 11, 15, 3, 81, 3, 69, 72, 68, 3, 88,\n", " 42, 68, 3, 38, 19, 12, 41, 76, 87, 11, 3, 87, 3, 76, 79, 72, 82, 3,\n", " 70, 2, 2, 67, 3, 15, 80, 87, 2, 54, 2, 11, 11, 3, 70, 87, 87, 70,\n", " 75, 12, 11, 82, 87, 2, 81, 70, 70, 21, 82, 81, 3, 72, 86, 72, 2, 47,\n", " 87, 3, 89, 78, 29, 85, 68, 12, 86, 82, 49, 3, 3, 81, 3, 73, 76, 83,\n", " 3, 68, 79, 3, 72, 3, 3, 3, 72, 3, 83, 63, 3, 86, 82, 73, 71, 72,\n", " 54, 72, 3, 17, 92, 73, 3, 86, 3, 82, 15, 3, 72, 74, 3, 72, 3, 17,\n", " 3, 72, 81, 75, 72, 3, 3, 3, 87, 37, 75, 73, 3, 5, 86, 2, 5, 81,\n", " 68, 3, 15, 2, 85, 85, 3, 3, 3, 72, 3, 80, 72, 68, 3, 85, 11, 83,\n", " 62, 72, 3, 81, 86, 2, 75, 79, 3, 80], device='cuda:0')}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 构建数据读取器\n", "train_loader = DataLoader(tokenized['train'], batch_size=batch_size, shuffle=True)\n", "test_loader = DataLoader(tokenized['test'], batch_size=batch_size, shuffle=True)\n", "# 获取一个批量的数据\n", "next(iter(test_loader))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "WVXSXA6vsiSS" }, "outputs": [], "source": [ "class CharMLP(nn.Module):\n", "\n", " def __init__(self, vs):\n", " '''\n", " 根据文本背景预测下一个字母是什么\n", " 参数\n", " ----\n", " vs :int,字典大小\n", " '''\n", " super().__init__()\n", " # 文字嵌入层\n", " self.embedding = nn.Embedding(vs, 30)\n", " self.hidden1 = nn.Linear(300, 200)\n", " self.hidden2 = nn.Linear(200, 100)\n", " self.out = nn.Linear(100, vs)\n", "\n", " def forward(self, x):\n", " '''\n", " 向前传播\n", " 参数\n", " ----\n", " x :torch.LongTensor,背景文本,其中的元素表示相应位置的字母在字典中的位置\n", " 返回\n", " ----\n", " h :torch.FloatTensor,预测结果的logits\n", " '''\n", " # 因为背景文本的长度(context_length)等于10,\n", " # 所以x的形状是(B, 10),B表示批量数据的大小\n", " B = x.shape[0] # (B, 10)\n", " emb = self.embedding(x) # (B, 10, 30)\n", " h = emb.view(B, -1) # (B, 300)\n", " h = F.relu(self.hidden1(h)) # (B, 200)\n", " h = F.relu(self.hidden2(h)) # (B, 100)\n", " h = self.out(h) # (B, vs)\n", " return h\n", "\n", "model = CharMLP(len(tok.char2ind)).to(device)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "82S3F9IXsiSS" }, "outputs": [], "source": [ "@torch.no_grad()\n", "def generate(model, context, max_new_tokens=300):\n", " '''\n", " 利用模型生成文本(反复使用模型进行预测)\n", " 参数\n", " ----\n", " model :CharMLP,生成文本的模型\n", " context :torch.LongTensor,背景文本,形状为(1, 10)\n", " max_new_tokens :int,生成文本的最大长度\n", " 返回\n", " ----\n", " out :list[int],生成的文本\n", " '''\n", " out = []\n", " # 将模型切换至评估模式\n", " model.eval()\n", " for _ in range(max_new_tokens):\n", " logits = model(context)\n", " probs = F.softmax(logits, dim=-1)\n", " # 根据模型预测的概率,得到最终的预测结果(下一个字母)\n", " # 这一步运算有一定随机性\n", " ix = torch.multinomial(probs, num_samples=1)\n", " # 利用模型的预测结果更新文本背景\n", " context = torch.cat((context[:, 1:], ix), dim=1)\n", " out.append(ix.item())\n", " if ix.item() == tok.end_ind:\n", " break\n", " # 将模型切换至训练模式\n", " model.train()\n", " return out" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4zSHFccTsiST", "outputId": "e0a6ca3f-fe74-49c6-caf1-0080eef4b158" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ")YN'E.ne'!XOzAYD\n", "F{tvö290&^#>P8(MZzJP\n" ] } ], "source": [ "# 使用模型来生成文本\n", "context = torch.zeros((1, 10), dtype=torch.long, device=device)\n", "print(''.join(tok.decode(generate(model, context))))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bo3UzjmXsiST", "outputId": "37337a6e-a3c0-40aa-84cb-6b88ec257da7" }, "outputs": [ { "data": { "text/plain": [ "{'train': 4.5956830978393555, 'test': 4.594418525695801}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def estimate_loss(model):\n", " re = {}\n", " # 将模型切换至评估模式\n", " model.eval()\n", " re['train'] = _loss(model, train_loader)\n", " re['test'] = _loss(model, test_loader)\n", " # 将模型切换至训练模式\n", " model.train()\n", " return re\n", "\n", "@torch.no_grad()\n", "def _loss(model, data_loader):\n", " \"\"\"\n", " 计算模型在不同数据集下面的评估指标\n", " \"\"\"\n", " loss = []\n", " data_iter= iter(data_loader)\n", " # 随机使用多个批量数据来预估模型效果\n", " for k in range(eval_iters):\n", " data = next(data_iter, None)\n", " if data is None:\n", " data_iter = iter(data_loader)\n", " data = next(data_iter, None)\n", " inputs, labels = data['inputs'], data['labels']\n", " logits = model(inputs)\n", " loss.append(F.cross_entropy(logits, labels).item())\n", " return torch.tensor(loss).mean().item()\n", "\n", "estimate_loss(model)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "18nkKzWLsiST" }, "outputs": [], "source": [ "def train_mlp(model, optimizer, data_loader, epochs=10):\n", " # 记录模型在训练集上的模型损失\n", " lossi = []\n", " for epoch in range(epochs):\n", " for i, data in enumerate(data_loader, 0):\n", " inputs, labels = data['inputs'], data['labels']\n", " optimizer.zero_grad()\n", " logits = model(inputs)\n", " loss = F.cross_entropy(logits, labels)\n", " lossi.append(loss.item())\n", " loss.backward()\n", " optimizer.step()\n", " # 评估模型,并输出结果\n", " stats = estimate_loss(model)\n", " train_loss = f'train loss {stats[\"train\"]:.4f}'\n", " test_loss = f'test loss {stats[\"test\"]:.4f}'\n", " print(f'epoch {epoch:>2}: {train_loss}, {test_loss}')\n", " return lossi" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "n_af5vmnsiST", "outputId": "80a2167d-ff62-4261-b7c8-4e63f2b7287f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0: train loss 1.3726, test loss 1.5097\n", "epoch 1: train loss 1.2598, test loss 1.4965\n", "epoch 2: train loss 1.1934, test loss 1.4247\n", "epoch 3: train loss 1.1630, test loss 1.4014\n", "epoch 4: train loss 1.1505, test loss 1.3658\n", "epoch 5: train loss 1.1539, test loss 1.3594\n", "epoch 6: train loss 1.0862, test loss 1.3975\n", "epoch 7: train loss 1.0872, test loss 1.3718\n", "epoch 8: train loss 1.0707, test loss 1.3832\n", "epoch 9: train loss 1.0845, test loss 1.3286\n" ] } ], "source": [ "l = train_mlp(model, optim.Adam(model.parameters(), lr=learning_rate), train_loader)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 448 }, "id": "a4tA30XfiMFv", "outputId": "d8065926-15f2-4f0e-e3dd-2a5e3ef16768" }, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(torch.tensor(l).view(-1, 10).mean(1).numpy())" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "JQ23lyBOsiST", "outputId": "0e3c46a9-e58a-489f-cfb2-46e41d7c688f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "def gandas(self).s))).Streamed for i += num))\n", " \"\"\"\n", " batchraces.\n", " ...\n", " \"\"\"\n", " if self._jvm.SSL0 0.1:\n", " name = thod:\n", " \"\"\"\n", " ... (argitparam) for Jrbteast = df.tors.defaultBy recrient short \n" ] } ], "source": [ "# 使用模型来生成文本\n", "context = torch.zeros((1, 10), dtype=torch.long, device=device)\n", "print(''.join(tok.decode(generate(model, context))))" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "V100", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 1 }