{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ZTO7hf-zM-np", "outputId": "e8944765-badd-4a6c-c111-0ca802c1f23b" }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.nn.utils import clip_grad_norm_\n", "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "import torch.optim as optim\n", "from datasets import load_dataset\n", "from transformers import pipeline\n", "\n", "torch.manual_seed(12046)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "dJhQvyIYM-nr" }, "outputs": [], "source": [ "# 一些超参数\n", "learning_rate = 5e-5\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "gamma = 1.0\n", "lambda_ = 0.95\n", "kl_ctl_value = 0.2\n", "cliprange = 0.2\n", "vf_coef = 0.1\n", "# 经过mini_batch_size步后,更新旧模型\n", "mini_batch_size = 20\n", "grad_clip = 1.0" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "Bh967HO6M-ns" }, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained('gpt2')\n", "tokenizer.pad_token = tokenizer.eos_token" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "ZZWmZrRlM-ns" }, "outputs": [], "source": [ "def prepare_input(data):\n", " data['input_ids'] = [tokenizer.encode(data['text'])[:8]]\n", " return data\n", "\n", "datasets = load_dataset('imdb', split='train[:500]')\n", "datasets = datasets.filter(lambda x: len(x['text']) > 20)\n", "tokenized = datasets.map(prepare_input, remove_columns=datasets.column_names)\n", "tokenized.set_format(type='torch', device=device)\n", "example = tokenized[1]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "cH6wexbYM-nw", "scrolled": true }, "outputs": [], "source": [ "class A2CLLM(nn.Module):\n", "\n", " def __init__(self, model):\n", " super().__init__()\n", " self.actor = model\n", " # 值函数估计头\n", " self.critic = nn.Linear(model.base_model.embed_dim, 1, bias=False)\n", "\n", " def forward(self, x):\n", " '''\n", " 向前传播,为了使代码易懂,该函数只支持单条文本的计算\n", " 参数\n", " ----\n", " x :torch.LongTensor,文本,形状为(1, T)\n", " 返回\n", " ----\n", " logits :torch.FloatTensor,logits,形状为(1, T, vs)\n", " values :torch.FloatTensor,值函数,形状为(1, T)\n", " '''\n", " _res = self.actor(input_ids=x, output_hidden_states=True)\n", " logits = _res.logits\n", " emb = _res.hidden_states[-1]\n", " values = self.critic(emb).squeeze(-1)\n", " return logits, values\n", "\n", " def generate(self, idx, max_new_tokens=20):\n", " '''\n", " 生成文本\n", " '''\n", " model = self.actor\n", " return model.generate(idx, max_new_tokens=max_new_tokens,\n", " pad_token_id=tokenizer.eos_token_id)\n", "\n", "model = A2CLLM(AutoModelForCausalLM.from_pretrained('lvwerra/gpt2-imdb')).to(device)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "EO_cOnnCM-nw" }, "outputs": [], "source": [ "from peft import LoraConfig, PeftModel\n", "\n", "def init_peft_model(model):\n", " config = LoraConfig(\n", " r=1,\n", " lora_alpha=8,\n", " target_modules=['c_attn'],\n", " fan_in_fan_out=True,\n", " lora_dropout=0.1,\n", " bias='none',\n", " modules_to_save=['critic'])\n", " return PeftModel(model, config, adapter_name='lora_ppo')\n", "\n", "# 增加LoRA适配器\n", "model = init_peft_model(model)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": 
"tv2HXFQwM-nw", "outputId": "fbd5f1e8-7cfd-4148-df83-1cc6059958b5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "logits torch.Size([1, 20, 50257])\n", "lnp torch.Size([1, 20])\n", "values torch.Size([1, 20])\n" ] } ], "source": [ "def get_forward_result(model, input_ids, response):\n", " '''\n", " 记录向前传播的结果,分别是logits,lnp和值函数\n", " 为了使代码易懂,该函数只支持单条文本的计算\n", " '''\n", " # 记录背景文本的长度\n", " _, lens = input_ids.shape\n", " logits, values = model(response)\n", " # 计算交叉熵的时候,需要注意logits和标签的对应关系\n", " lnp = -F.cross_entropy(logits[:, :-1, :].transpose(-2, -1), response[:, 1:], reduction='none')\n", " # 只记录针对生成文本的结果,其中L表示生成文本的长度\n", " res = {\n", " # 最后一个位置的logits没有作用\n", " 'logits': logits[:, lens-1:-1, :], # (1, L, vs)\n", " 'lnp': lnp[:, lens-1:], # (1, L)\n", " 'values': values[:, lens:] # (1, L)\n", " }\n", " return res\n", "\n", "\n", "input_ids = example['input_ids']\n", "response = model.generate(input_ids)\n", "\n", "# 验证get_forward_result计算结果的形状是准确的\n", "example_re = get_forward_result(model, input_ids, response)\n", "for k, v in example_re.items():\n", " print(k, v.shape)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Rkl8rTvBcOQX", "outputId": "01dadb70-33c5-4a79-8991-199b24d80d81" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[ 0.3356, -0.3501, -0.6011, -0.4132, 1.0261, 0.8811, -0.3165, 0.4929,\n", " -0.9196, -0.3321, -0.2723, -0.1996, -0.6541, 0.1892, 0.6956, 0.3488,\n", " 0.2956, 0.3583, 0.2754, 0.5844, 0.7313, 0.1374, 0.5127, -0.1030,\n", " 0.5666, -0.0081, 0.3219, -0.0353]], device='cuda:0',\n", " grad_fn=)\n", "tensor([[ 0.0418, 0.5579, 0.5273, 1.0549, 0.5402, 0.1473, 0.3205, 0.0311,\n", " 0.6900, -0.2323, 0.1526, 0.4450, 0.1746, 0.6160, -0.2214, -0.1989,\n", " 0.1022, 0.2701, -0.0173, -0.0539, -0.1477, 0.0678, -0.0153, -0.6429,\n", " -0.3822, -0.4266, -0.2184, -0.4352]], device='cuda:0',\n", " grad_fn=)\n", "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0.]], device='cuda:0', grad_fn=)\n" ] } ], "source": [ "def turn_on_train_mode(model, target):\n", " '''\n", " 只将模型中的特定组件设置为训练模式\n", " '''\n", " for name, module in model.named_modules():\n", " if name.split('.')[-1] in target:\n", " module.train()\n", " return model\n", "\n", "def _test_turn_on_train_mode():\n", " '''\n", " 测试turn_on_train_mode是否正确\n", " '''\n", " test_model = A2CLLM(\n", " AutoModelForCausalLM.from_pretrained('lvwerra/gpt2-imdb')).to(device)\n", " config = LoraConfig(\n", " r=1,\n", " lora_alpha=8,\n", " target_modules=['c_attn'],\n", " fan_in_fan_out=True,\n", " lora_dropout=0.1,\n", " bias='none',\n", " init_lora_weights=False)\n", " test_model = PeftModel(test_model, config, adapter_name='lora_ppo')\n", " # 模型处于训练模式,由于随机失活的原因,每次运算的结果都不相同\n", " test_model.train()\n", " v1 = test_model(response)[1]\n", " v2 = test_model(response)[1]\n", " print(v1 - v2)\n", "\n", " test_model.eval()\n", " # 只将LoRA换至训练模式,由于LoRA里的随机失活,每次运算的结果也不相同\n", " turn_on_train_mode(test_model, ['c_attn'])\n", " v1 = test_model(response)[1]\n", " v2 = test_model(response)[1]\n", " print(v1 - v2)\n", "\n", " test_model.eval()\n", " turn_on_train_mode(test_model, ['c_attn'])\n", " # 禁用LoRA之后,运算结果会相同\n", " with test_model.disable_adapter():\n", " v1 = test_model(response)[1]\n", " v2 = test_model(response)[1]\n", " print(v1 - v2)\n", "\n", "_test_turn_on_train_mode()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { 
"colab": { "base_uri": "https://localhost:8080/" }, "id": "wW8OBSb6M-nx", "outputId": "a2faedc5-e036-4c3e-b29a-0f4acb924448" }, "outputs": [ { "data": { "text/plain": [ "tensor([0.9959])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class RewardModel(nn.Module):\n", "\n", " def __init__(self, tokenizer):\n", " '''\n", " 评分模型\n", " '''\n", " super().__init__()\n", " self.model = pipeline(\"sentiment-analysis\", model='lvwerra/distilbert-imdb')\n", " self.tokenizer = tokenizer\n", "\n", " def forward(self, x):\n", " '''\n", " 向前传播,为了使代码易懂,该函数只支持单条文本的计算\n", " 参数\n", " ----\n", " x :torch.LongTensor,文本,形状为(1, T)\n", " 返回\n", " ----\n", " re :torch.FloatTensor,评分,形状为(1)\n", " '''\n", " re = []\n", " x = [self.tokenizer.decode(i) for i in x]\n", " # 此处的x等于背景文本+生成文本,因此得到的scores稍有不妥\n", " # 更准确的做法是只对生成文本进行评分\n", " scores = self.model(x)\n", " for s in scores:\n", " # 将POSITIVE的概率视为评分\n", " if s['label'] == 'POSITIVE':\n", " re.append(s['score'])\n", " else:\n", " re.append(1 - s['score'])\n", " return torch.tensor(re)\n", "\n", "r_model = RewardModel(tokenizer).to(device)\n", "r_model(response)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1rGQcB0PM-nx", "outputId": "2dfc1902-72f5-43af-9234-b224ebd52959" }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 20])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def compute_rewards(r_model, response, lnp, ref_lnp):\n", " '''\n", " 定义游戏奖励\n", " 为了使代码易懂,该函数只支持单条文本的计算\n", " '''\n", " # scores的形状为(1), lnp的形状为(1, L), ref_lnp的形状为(1, L)\n", " # r_model:评分模型,response:模型生成的回答\n", " # lnp:新/旧模型的概率对数,ref_lnp:参考模型的概率对数\n", " scores = r_model(response)\n", " rewards = []\n", " for score, lnprob, ref_lnprob in zip(scores, lnp, ref_lnp):\n", " kl = lnprob - ref_lnprob # ( L)\n", " # kl_ctl_value是调节KL penalty的系数,大于0\n", " reward = -kl_ctl_value * kl # ( L)\n", " # 游戏奖励等于模型评分 + KL penalty\n", " reward[-1] += score # ( L)\n", " rewards.append(reward)\n", " return torch.stack(rewards) # (1, L)\n", "\n", "# 得到参考模型的结果\n", "with torch.no_grad():\n", " with model.disable_adapter():\n", " ref_example_re = get_forward_result(model, input_ids, response)\n", "\n", "rewards = compute_rewards(r_model, response, example_re['lnp'], ref_example_re['lnp'])\n", "rewards.shape" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "huNDSgciM-nx" }, "outputs": [], "source": [ "class GAE:\n", "\n", " def __init__(self, gamma, lambda_):\n", " self.gamma = gamma\n", " self.lambda_ = lambda_\n", "\n", " def __call__(self, rewards, values):\n", " # 优势函数\n", " advantages = []\n", " last_advantage = 0\n", " vt_next = 0\n", " for r, vt in zip(reversed(rewards), reversed(values)):\n", " delta = r + self.gamma * vt_next - vt\n", " last_advantage = delta + self.gamma * self.lambda_ * last_advantage\n", " advantages.insert(0, last_advantage)\n", " vt_next = vt\n", "\n", " return torch.stack(advantages)\n", "\n", "gae = GAE(gamma, lambda_)\n", "advantages = gae(rewards, example_re['values'])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MU3Sz6iwM-ny", "outputId": "c104d182-1547-4596-8092-4d1d117cbe95" }, "outputs": [ { "data": { "text/plain": [ "tensor(-0.2746, device='cuda:0', grad_fn=)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def compute_loss(old_lnp, lnp, vpred, 
{ "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MU3Sz6iwM-ny", "outputId": "c104d182-1547-4596-8092-4d1d117cbe95" }, "outputs": [ { "data": { "text/plain": [ "tensor(-0.2746, device='cuda:0', grad_fn=<AddBackward0>)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def compute_loss(old_lnp, lnp, vpred, advantages):\n", "    '''\n", "    Define the model loss.\n", "    To keep the code easy to follow, this function only supports a single text.\n", "    '''\n", "    # old_lnp: log-probs of the old model, shape (1, L)\n", "    # lnp: log-probs of the new model, shape (1, L)\n", "    # vpred: value estimates, shape (1, L)\n", "    # advantages: advantage estimates, shape (1, L)\n", "    # Value-function loss: advantages are detached, so the gradient w.r.t.\n", "    # vpred is -advantages, which nudges the values toward the observed returns\n", "    vf_loss = -advantages * vpred\n", "    # Policy loss (PPO clipped surrogate)\n", "    ratio = torch.exp(lnp - old_lnp)\n", "    pg_losses = -advantages * ratio\n", "    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)\n", "    pg_loss = torch.max(pg_losses, pg_losses2)\n", "    # Total loss\n", "    loss = pg_loss.mean() + vf_coef * vf_loss.mean()\n", "    return loss\n", "\n", "compute_loss(example_re['lnp'], example_re['lnp'], example_re['values'], advantages)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "v_yIwG4tM-nz", "outputId": "adb33a65-0b81-4955-f4fb-85be56dc33f3" }, "outputs": [ { "data": { "text/plain": [ "[tensor([[   40, 26399,   314,  3001,   327, 47269, 20958,    12]],\n", "        device='cuda:0'),\n", " tensor([[    1,    40,  1703, 44269,    25, 12550,     1,   318]],\n", "        device='cuda:0')]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def play_game(model, r_model, gae, data):\n", "    model.eval()\n", "    # Prompts, replies, forward-pass results and advantage estimates\n", "    all_input_ids, all_response, all_res, all_advantages = [], [], [], []\n", "    for input_ids in data['input_ids']:\n", "        all_input_ids.append(input_ids)\n", "        # Generate a review\n", "        response = model.generate(input_ids)\n", "        all_response.append(response)\n", "        with torch.no_grad():\n", "            # Record the old model's data\n", "            res = get_forward_result(model, input_ids, response)\n", "            all_res.append(res)\n", "            # Record the reference model's data\n", "            with model.disable_adapter():\n", "                ref_res = get_forward_result(model, input_ids, response)\n", "        rewards = compute_rewards(r_model, response, res['lnp'], ref_res['lnp'])\n", "        all_advantages.append(gae(rewards, res['values']))\n", "    # Switch only the LoRA adapter back to training mode\n", "    turn_on_train_mode(model, ['c_attn'])\n", "    return all_input_ids, all_response, all_res, all_advantages\n", "\n", "play_game(model, r_model, gae, tokenized[:2])[0]" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "sPjg11nEM-nz", "outputId": "bcc0c350-c4bd-4485-85bc-2c718bc6a020" }, "outputs": [ { "data": { "text/plain": [ "{'score': 0.5244841426610947, 'ref_score': 0.5244841426610947}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def estimate_rewards(r_model, model, all_input_ids):\n", "    '''\n", "    Estimate the average reward-model score\n", "    '''\n", "    re = {}\n", "    # Switch the model to evaluation mode\n", "    model.eval()\n", "    for input_ids in all_input_ids:\n", "        # Generate text\n", "        response = model.generate(input_ids)\n", "        # Accumulate the score\n", "        re['score'] = re.get('score', 0) + r_model(response).item()\n", "        # Accumulate the reference model's score\n", "        with model.disable_adapter():\n", "            response = model.generate(input_ids)\n", "            re['ref_score'] = re.get('ref_score', 0) + r_model(response).item()\n", "    re['score'] /= len(all_input_ids)\n", "    re['ref_score'] /= len(all_input_ids)\n", "    # Switch only the LoRA adapter back to training mode\n", "    turn_on_train_mode(model, ['c_attn'])\n", "    return re\n", "\n", "estimate_rewards(r_model, model, tokenized[:20]['input_ids'])" ] },
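{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative aside added to this walkthrough (not part of the original\n", "# flow): how the clipped surrogate caps the incentive to drift from the\n", "# old policy. With advantage +1 and cliprange = 0.2, pushing the\n", "# probability ratio beyond 1.2 earns no further reduction in the loss.\n", "_adv = torch.tensor(1.0)\n", "for _ratio in [0.5, 0.9, 1.0, 1.1, 1.5]:\n", "    _ratio = torch.tensor(_ratio)\n", "    unclipped = -_adv * _ratio\n", "    clipped = -_adv * torch.clamp(_ratio, 1.0 - cliprange, 1.0 + cliprange)\n", "    print(f'ratio {_ratio:.1f}: pg_loss {torch.max(unclipped, clipped).item():.2f}')" ] },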
{ "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mizmt8YrM-n0", "outputId": "f6614c34-3780-4655-821a-061daa007119" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "step    0: score 0.5412, ref_score 0.5085\n", "step    1: score 0.5412, ref_score 0.5085\n", "step    2: score 0.5085, ref_score 0.5085\n", "step    3: score 0.5412, ref_score 0.5085\n", "step    4: score 0.5180, ref_score 0.5085\n", "step    5: score 0.5182, ref_score 0.5085\n", "step    6: score 0.4743, ref_score 0.5085\n", "step    7: score 0.4743, ref_score 0.5085\n", "step    8: score 0.4741, ref_score 0.5085\n", "step    9: score 0.4741, ref_score 0.5085\n", "step   10: score 0.4725, ref_score 0.5085\n", "step   11: score 0.5210, ref_score 0.5085\n", "step   12: score 0.5225, ref_score 0.5085\n", "step   13: score 0.5168, ref_score 0.5085\n", "step   14: score 0.5184, ref_score 0.5085\n", "step   15: score 0.5135, ref_score 0.5085\n", "step   16: score 0.5147, ref_score 0.5085\n", "step   17: score 0.5129, ref_score 0.5085\n", "step   18: score 0.6062, ref_score 0.5085\n", "step   19: score 0.6182, ref_score 0.5085\n", "step   20: score 0.6737, ref_score 0.5085\n", "step   21: score 0.6730, ref_score 0.5085\n", "step   22: score 0.6731, ref_score 0.5085\n", "step   23: score 0.6724, ref_score 0.5085\n" ] } ], "source": [ "steps = datasets.num_rows // mini_batch_size\n", "optimizer = optim.AdamW(model.parameters(), lr=learning_rate)\n", "\n", "for s in range(steps-1):\n", "    data = tokenized[s * mini_batch_size: (s + 1) * mini_batch_size]\n", "    # Play the game and collect data. Everything play_game returns is detached\n", "    # from the computation graph; inside play_game the reference model is\n", "    # obtained from model by disabling its adapter\n", "    input_ids, response, old_res, advantages = play_game(model, r_model, gae, data)\n", "    # Only after this inner loop finishes does the new model replace the old one\n", "    for _ids, _resp, _old_res, _ad in zip(input_ids, response, old_res, advantages):\n", "        optimizer.zero_grad(set_to_none=True)\n", "        # Collect the new model's data; tensors in model_res carry gradients\n", "        model_res = get_forward_result(model, _ids, _resp)\n", "        loss = compute_loss(_old_res['lnp'], model_res['lnp'], model_res['values'], _ad)\n", "        loss.backward()\n", "        # Gradient clipping\n", "        clip_grad_norm_(model.parameters(), grad_clip)\n", "        optimizer.step()\n", "    # Use the last mini-batch as the test set\n", "    res = estimate_rewards(r_model, model, tokenized[-mini_batch_size:]['input_ids'])\n", "    print(f'step {s:>4}: score {res[\"score\"]:.4f}, ref_score {res[\"ref_score\"]:.4f}')" ] }
], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "V100", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 1 }