pos_embed.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------

import numpy as np
import torch


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed from NumPy; use np.float64
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
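

def _sincos_shape_example():
    # Minimal usage sketch (not part of the original utilities). The 14x14 grid and
    # embed_dim=768 are illustrative ViT-B/16-style values, not values taken from
    # this file. Each of the two grid axes gets embed_dim // 2 channels, split
    # evenly between sin and cos components.
    pos_embed = get_2d_sincos_pos_embed(768, 14, cls_token=True)
    return pos_embed.shape  # expected: (1 + 14 * 14, 768)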
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
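

if __name__ == "__main__":
    # Minimal interpolation sketch (illustrative, not part of the original module):
    # resize a checkpoint's position embedding from a hypothetical 14x14 grid to a
    # 16x16 grid. _StubModel only carries the two attributes interpolate_pos_embed()
    # reads: num_patches and pos_embed, both sized for the new 16x16 grid.
    class _StubModel:
        num_patches = 16 * 16
        pos_embed = torch.zeros(1, 1 + 16 * 16, 768)

    checkpoint = {'pos_embed': torch.randn(1, 1 + 14 * 14, 768)}
    interpolate_pos_embed(_StubModel(), checkpoint)
    print(checkpoint['pos_embed'].shape)  # expected: torch.Size([1, 257, 768])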