In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel, GPT2Model


# Fix the global PyTorch RNG so every Restart & Run All is reproducible.
SEED = 12046
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fc57cc63110>

In [None]:
learning_rate = 6e-4
sequence_len = 1024
batch_size = 8
gra_acc_steps = 8 * 2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 64 * 2
eval_interval = 50

In [12]:
# Language model and its tokenizer, both loaded from the same pretrained checkpoint.
GPT2_CHECKPOINT = 'gpt2'
llm = GPT2LMHeadModel.from_pretrained(GPT2_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(GPT2_CHECKPOINT)

In [11]:
class RewardModel(nn.Module):
    """Scalar reward head on top of a transformer backbone.

    The backbone's hidden state at the last real token of each text is
    projected to a single score by a bias-free linear layer.

    Args:
        model: backbone module. It must expose ``embed_dim``, be callable as
            ``model(input_ids)`` and ``model(inputs_embeds=...)`` returning an
            object with ``last_hidden_state``, and provide
            ``get_input_embeddings()`` (e.g. a Hugging Face ``GPT2Model``).
    """

    def __init__(self, model):
        super().__init__()
        self.embedding = model
        self.score = nn.Linear(model.embed_dim, 1, bias=False)

    def forward(self, x, seq_len=None):
        """Score a batch of texts.

        Args:
            x: the texts, either token ids of shape (B, T) or a (soft) one-hot
                distribution over the vocabulary of shape (B, T, vs).
            seq_len: length of each text, shape (B,); a tensor or any sequence
                of ints. Defaults to T for every row (no padding).

        Returns:
            Tensor of shape (B, 1) holding one score per text.
        """
        B, T = x.shape[0], x.shape[1]
        emb = self.get_last_hidden_state(x)                   # (B, T, C)
        # Build / move the index tensors onto the device of `emb` so indexing
        # never mixes CPU and CUDA tensors.
        if seq_len is None:
            seq_len = torch.full((B,), T, dtype=torch.long, device=emb.device)
        else:
            seq_len = torch.as_tensor(seq_len, dtype=torch.long, device=emb.device)
        rows = torch.arange(B, device=emb.device)
        # Take the features of the last real token of each text.
        # NOTE: assumes right padding; with a causal backbone such as GPT-2 the
        # padding after seq_len cannot influence this position.
        pooled_emb = emb[rows, seq_len - 1]                   # (B, C)
        return self.score(pooled_emb)                         # (B, 1)

    def get_last_hidden_state(self, x):
        """Run the backbone and return its last hidden state, shape (B, T, C).

        A (B, T) input is treated as token ids. A (B, T, vs) input (e.g. a
        gumbel_softmax sample) is multiplied directly with the backbone's input
        embedding matrix, so the lookup stays differentiable w.r.t. ``x``.

        Raises:
            ValueError: if ``x`` is neither 2-D nor 3-D.
        """
        if x.dim() == 2:
            # x: (B, T) token ids
            return self.embedding(x).last_hidden_state        # (B, T, C)
        if x.dim() == 3:
            # x: (B, T, vs) distribution over the vocabulary
            w = self.embedding.get_input_embeddings().weight  # (vs, C)
            inputs_embeds = x.to(w.dtype) @ w                 # (B, T, C)
            return self.embedding(inputs_embeds=inputs_embeds).last_hidden_state
        raise ValueError(
            f"expected x of shape (B, T) or (B, T, vs), got {tuple(x.shape)}"
        )

# Reward model: a fresh pretrained GPT-2 backbone (no LM head) plus a scalar score head.
reward_backbone = GPT2Model.from_pretrained('gpt2')
r_model = RewardModel(reward_backbone)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)