# vit.py
# --------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------------------
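"""A plain Vision Transformer (ViT) image encoder: patch embedding, a fixed 2D
sine-cosine positional embedding, and a stack of Transformer blocks (see
`build_vit` below for the tiny/small/base/large/huge configurations)."""
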
import torch
import torch.nn as nn

try:
    from .modules import PatchEmbed, ViTBlock
except ImportError:
    from modules import PatchEmbed, ViTBlock


# ---------------------- Vision transformer ----------------------
class ImageEncoderViT(nn.Module):
    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 patch_embed_dim: int,
                 depth: int,
                 num_heads: int,
                 mlp_ratio: float,
                 act_layer: type[nn.Module],
                 dropout: float = 0.0,
                 ) -> None:
        super().__init__()
        # ----------- Basic parameters -----------
        self.img_size = img_size
        self.patch_size = patch_size
        self.image_embedding_size = img_size // (patch_size if patch_size > 0 else 1)
        self.patch_embed_dim = patch_embed_dim
        self.num_heads = num_heads
        self.num_patches = (img_size // patch_size) ** 2
        # ----------- Model parameters -----------
        self.patch_embed = PatchEmbed(in_chans, patch_embed_dim, patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, patch_embed_dim))
        self.norm_layer = nn.LayerNorm(patch_embed_dim)
        self.blocks = nn.ModuleList([
            ViTBlock(patch_embed_dim, num_heads, mlp_ratio, True, act_layer, dropout)
            for _ in range(depth)])

        self._init_weights()

    def _init_weights(self):
        # initialize pos_embed with a 2D sin-cos embedding
        pos_embed = self.get_posembed(self.pos_embed.shape[-1], int(self.num_patches ** 0.5))
        self.pos_embed.data.copy_(pos_embed)

        # initialize nn.Linear and nn.LayerNorm
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # we use xavier_uniform following the official JAX ViT
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def get_posembed(self, embed_dim, grid_size, temperature=10000):
        scale = 2 * torch.pi
        grid_h, grid_w = grid_size, grid_size
        num_pos_feats = embed_dim // 2
        # get grid
        y_embed, x_embed = torch.meshgrid(torch.arange(grid_h, dtype=torch.float32),
                                          torch.arange(grid_w, dtype=torch.float32),
                                          indexing='ij')
        # normalize grid coords
        y_embed = y_embed / (grid_h + 1e-6) * scale
        x_embed = x_embed / (grid_w + 1e-6) * scale

        # sinusoidal frequencies: (sin, cos) pairs at geometrically spaced wavelengths
        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
        dim_t = temperature ** (2 * dim_t_)

        pos_x = torch.div(x_embed[..., None], dim_t)
        pos_y = torch.div(y_embed[..., None], dim_t)
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)

        # [H, W, C] -> [N, C], y-embedding first, then x-embedding
        pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)

        return pos_embed.unsqueeze(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Patch embed
        x = self.patch_embed(x)
        x = x.flatten(2).permute(0, 2, 1).contiguous()
        # Add pos embed
        x = x + self.pos_embed
        # Apply Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm_layer(x)

        return x


# ------------------------ Model Functions ------------------------
def build_vit(model_name="vit_t", img_size=224, patch_size=16, img_dim=3):
    if model_name == "vit_t":
        return ImageEncoderViT(img_size=img_size,
                               patch_size=patch_size,
                               in_chans=img_dim,
                               patch_embed_dim=192,
                               depth=12,
                               num_heads=3,
                               mlp_ratio=4.0,
                               act_layer=nn.GELU,
                               dropout=0.1)
    if model_name == "vit_s":
        return ImageEncoderViT(img_size=img_size,
                               patch_size=patch_size,
                               in_chans=img_dim,
                               patch_embed_dim=384,
                               depth=12,
                               num_heads=6,
                               mlp_ratio=4.0,
                               act_layer=nn.GELU,
                               dropout=0.1)
    if model_name == "vit_b":
        return ImageEncoderViT(img_size=img_size,
                               patch_size=patch_size,
                               in_chans=img_dim,
                               patch_embed_dim=768,
                               depth=12,
                               num_heads=12,
                               mlp_ratio=4.0,
                               act_layer=nn.GELU,
                               dropout=0.1)
    if model_name == "vit_l":
        return ImageEncoderViT(img_size=img_size,
                               patch_size=patch_size,
                               in_chans=img_dim,
                               patch_embed_dim=1024,
                               depth=24,
                               num_heads=16,
                               mlp_ratio=4.0,
                               act_layer=nn.GELU,
                               dropout=0.1)
    if model_name == "vit_h":
        return ImageEncoderViT(img_size=img_size,
                               patch_size=patch_size,
                               in_chans=img_dim,
                               patch_embed_dim=1280,
                               depth=32,
                               num_heads=16,
                               mlp_ratio=4.0,
                               act_layer=nn.GELU,
                               dropout=0.1)
    raise ValueError(f"Unknown model_name: {model_name}")
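
# A minimal usage sketch (assuming `modules.py` provides the PatchEmbed and
# ViTBlock imported above):
#   model = build_vit("vit_s", img_size=224, patch_size=16)
#   feats = model(torch.randn(1, 3, 224, 224))  # -> [1, 196, 384]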


if __name__ == '__main__':
    import torch
    from thop import profile

    # Prepare an image as the input
    bs, c, h, w = 2, 3, 224, 224
    x = torch.randn(bs, c, h, w)
    patch_size = 16

    # Build model
    model = build_vit(patch_size=patch_size)

    # Inference
    outputs = model(x)

    # Compute FLOPs & Params (thop reports MACs, so multiply by 2 for FLOPs)
    print('==============================')
    model.eval()
    flops, params = profile(model, inputs=(x,), verbose=False)
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))
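
    # For the default "vit_t" encoder at 224x224 with patch_size=16, `outputs`
    # has shape [bs, (224 // 16) ** 2, 192] = [2, 196, 192].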