|
|
@@ -839,7 +839,7 @@
|
|
|
" B, T = x.shape\n",
|
|
|
" pos = torch.arange(0, T, dtype=torch.long, device=x.device)\n",
|
|
|
" token_embeddings = self.token_emb(x) # (B, T, C)\n",
|
|
|
- " position_embeddings = self.pos_emb(pos) # (B, T, C)\n",
|
|
|
+ " position_embeddings = self.pos_emb(pos) # ( T, C)\n",
|
|
|
" h = token_embeddings + position_embeddings # (B, T, C)\n",
|
|
|
" h = self.blocks(h) # (B, T, C)\n",
|
|
|
" logits = self.lm(self.l(h)) # (B, T, vs)\n",
|