# %% [markdown]
# # Character embeddings: one-hot matrix multiplication vs. row lookup
#
# Shows that multiplying a one-hot encoding by a weight matrix is the same as
# indexing rows of that matrix, then wraps the lookup in a small `Embedding` class.

# %%
import string

import torch
import torch.nn.functional as F

# Seed the global RNG so the random weights (and therefore every displayed
# tensor) are reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Vocabulary: map each lowercase ASCII letter to its position ('a' -> 0 ... 'z' -> 25).
# string.ascii_lowercase is already in alphabetical order, so no sort is needed.
char2indx = {ch: i for i, ch in enumerate(string.ascii_lowercase)}
char2indx

# %%
example = list('love')
example

# %%
# Use the vocabulary to turn the text into integer indices.
idx = torch.tensor([char2indx[ch] for ch in example])
idx, idx.shape

# %%
# One-hot encode the indices, turning the text into a 2-D tensor.
num_claz = len(char2indx)  # number of one-hot classes (26 letters)
dims = 5                   # embedding dimension
x = F.one_hot(idx, num_classes=num_claz).float()
x, x.shape

# %%
# A text embedding is really just a matrix multiplication.
W = torch.randn((num_claz, dims))
assert x.shape == (len(example), num_claz)  # ( 4, 26)
assert W.shape == (num_claz, dims)          # (26,  5)
x @ W                                       # ( 4,  5)

# %%
# Same result as the matrix multiplication above, but a friendlier implementation:
# the index tensor has fewer dimensions and no one-hot encoding is required.
assert idx.shape == (len(example),)         # ( 4,)
assert torch.allclose(x @ W, W[idx])        # row lookup == one-hot matmul
W[idx]                                      # ( 4,  5)


# %%
# A minimal example implementation of an embedding layer.
class Embedding:
    """Lookup table that maps integer indices to rows of a weight matrix.

    Attributes:
        weight: Tensor of shape (num_embeddings, embedding_dim), drawn from a
            standard normal distribution at construction time.
        out: Result of the most recent call; set by ``__call__``.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        """Create the table.

        Args:
            num_embeddings: Number of rows in the table (vocabulary size).
            embedding_dim: Length of each embedding vector.
        """
        self.weight = torch.randn((num_embeddings, embedding_dim))

    def __call__(self, idx: torch.Tensor) -> torch.Tensor:
        """Return the embedding rows selected by ``idx``.

        Args:
            idx: Integer indices into the table. Pass the indices themselves,
                NOT a one-hot encoding of them: every element is used as a
                row index, so the result has shape
                ``idx.shape + (embedding_dim,)``.

        Returns:
            ``self.weight[idx]``; the same tensor is also stored on ``self.out``.
        """
        self.out = self.weight[idx]
        return self.out

    def parameters(self) -> list:
        """Return the layer's parameters as a list (just the weight)."""
        return [self.weight]


# %%
# Correct usage: pass the integer indices.
emb = Embedding(num_claz, dims)
emb_out = emb(idx)  # look up once and reuse the result
emb_out, idx.shape, emb_out.shape

# %%
# An example with more dimensions:
# bidx can be read as 10 texts of length 11 (the unit of text is a letter).
bidx = torch.randint(0, num_claz, (10, 11))
bidx.shape, emb(bidx).shape

# %%
# Incorrect usage: x is the one-hot encoding, not the indices.
# Each 0/1 entry is treated as an index, so every position looks up row 0 or
# row 1 of the table and the result is (4, 26, 5) instead of the intended (4, 5).
# Only the shape is shown; the full tensor is 520 numbers of repeated rows.
wrong_out = emb(x.int())
assert wrong_out.shape == (*x.shape, dims)
wrong_out.shape