vit.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import PatchEmbed, ViTBlock, AttentionPoolingClassifier
  5. except:
  6. from modules import PatchEmbed, ViTBlock, AttentionPoolingClassifier
  7. # ---------- Vision transformer ----------
  8. class ImageEncoderViT(nn.Module):
  9. def __init__(self,
  10. img_size: int,
  11. patch_size: int,
  12. in_chans: int,
  13. patch_embed_dim: int,
  14. depth: int,
  15. num_heads: int,
  16. mlp_ratio: float,
  17. act_layer: nn.GELU,
  18. dropout: float = 0.0,
  19. ) -> None:
  20. super().__init__()
  21. # ----------- Basic parameters -----------
  22. self.img_size = img_size
  23. self.patch_size = patch_size
  24. self.patch_embed_dim = patch_embed_dim
  25. self.num_heads = num_heads
  26. self.num_patches = (img_size // patch_size) ** 2
  27. # ----------- Model parameters -----------
  28. self.patch_embed = PatchEmbed(in_chans, patch_embed_dim, patch_size, stride=patch_size)
  29. self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, patch_embed_dim))
  30. self.norm_layer = nn.LayerNorm(patch_embed_dim)
  31. self.blocks = nn.ModuleList([
  32. ViTBlock(patch_embed_dim, num_heads, mlp_ratio, True, act_layer, dropout)
  33. for _ in range(depth)])
  34. self._init_weights()
  35. def _init_weights(self):
  36. # initialize (and freeze) pos_embed by sin-cos embedding
  37. pos_embed = self.get_posembed(self.pos_embed.shape[-1], int(self.num_patches**.5))
  38. self.pos_embed.data.copy_(pos_embed)
  39. # initialize nn.Linear and nn.LayerNorm
  40. for m in self.modules():
  41. if isinstance(m, nn.Linear):
  42. # we use xavier_uniform following official JAX ViT:
  43. torch.nn.init.xavier_uniform_(m.weight)
  44. if isinstance(m, nn.Linear) and m.bias is not None:
  45. nn.init.constant_(m.bias, 0)
  46. elif isinstance(m, nn.LayerNorm):
  47. nn.init.constant_(m.bias, 0)
  48. nn.init.constant_(m.weight, 1.0)
  49. def get_posembed(self, embed_dim, grid_size, temperature=10000):
  50. scale = 2 * torch.pi
  51. grid_h, grid_w = grid_size, grid_size
  52. num_pos_feats = embed_dim // 2
  53. # get grid
  54. y_embed, x_embed = torch.meshgrid([torch.arange(grid_h, dtype=torch.float32),
  55. torch.arange(grid_w, dtype=torch.float32)])
  56. # normalize grid coords
  57. y_embed = y_embed / (grid_h + 1e-6) * scale
  58. x_embed = x_embed / (grid_w + 1e-6) * scale
  59. dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
  60. dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
  61. dim_t = temperature ** (2 * dim_t_)
  62. pos_x = torch.div(x_embed[..., None], dim_t)
  63. pos_y = torch.div(y_embed[..., None], dim_t)
  64. pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
  65. pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
  66. # [H, W, C] -> [N, C]
  67. pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)
  68. return pos_embed.unsqueeze(0)
  69. def forward(self, x: torch.Tensor) -> torch.Tensor:
  70. # Patch embed
  71. x = self.patch_embed(x)
  72. x = x.flatten(2).permute(0, 2, 1).contiguous()
  73. # Add pos embed
  74. x = x + self.pos_embed
  75. # Apply Transformer blocks
  76. for block in self.blocks:
  77. x = block(x)
  78. x = self.norm_layer(x)
  79. return x
  80. # ---------- Vision transformer for classification ----------
  81. class ViTForImageClassification(nn.Module):
  82. def __init__(self,
  83. image_encoder :ImageEncoderViT,
  84. num_classes :int = 1000,
  85. qkv_bias :bool = True,
  86. ):
  87. super().__init__()
  88. # -------- Model parameters --------
  89. self.encoder = image_encoder
  90. self.classifier = AttentionPoolingClassifier(image_encoder.patch_embed_dim,
  91. num_classes,
  92. image_encoder.num_heads,
  93. qkv_bias,
  94. num_queries=1)
  95. def forward(self, x):
  96. """
  97. Inputs:
  98. x: (torch.Tensor) -> [B, C, H, W]. Input image.
  99. """
  100. x = self.encoder(x)
  101. x, x_cls = self.classifier(x)
  102. return x
  103. if __name__=='__main__':
  104. import time
  105. # 构建ViT模型
  106. img_encoder = ImageEncoderViT(img_size=224,
  107. patch_size=16,
  108. in_chans=3,
  109. patch_embed_dim=192,
  110. depth=12,
  111. num_heads=3,
  112. mlp_ratio=4.0,
  113. act_layer=nn.GELU,
  114. dropout = 0.1)
  115. model = ViTForImageClassification(img_encoder, num_classes=10, qkv_bias=True)
  116. # 打印模型结构
  117. print(model)
  118. # 随即成生数据
  119. x = torch.randn(1, 3, 224, 224)
  120. # 模型前向推理
  121. t0 = time.time()
  122. output = model(x)
  123. t1 = time.time()
  124. print('Time: ', t1 - t0)