modules.py
# --------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Type, Tuple


# ----------------------- Basic modules -----------------------
class FeedForward(nn.Module):
    """Transformer MLP block: Linear -> activation -> dropout -> Linear -> dropout."""
    def __init__(self,
                 embedding_dim: int,
                 mlp_dim: int,
                 act: Type[nn.Module] = nn.GELU,
                 dropout: float = 0.0,
                 ) -> None:
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, mlp_dim)
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(mlp_dim, embedding_dim)
        self.drop2 = nn.Dropout(dropout)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
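

# --------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a quick
# shape check of FeedForward. The 768 / 3072 sizes and the 197-token
# sequence length are assumed ViT-Base-style values, not taken from
# this file.
# --------------------------------------------------------------------
def _demo_feed_forward() -> None:
    ffn = FeedForward(embedding_dim=768, mlp_dim=3072, dropout=0.1)
    x = torch.randn(2, 197, 768)        # [B, N, C]
    assert ffn(x).shape == x.shape      # the FFN preserves [B, N, C]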


class PatchEmbed(nn.Module):
    def __init__(self,
                 in_chans: int = 3,
                 embed_dim: int = 768,
                 kernel_size: int = 16,
                 padding: int = 0,
                 stride: int = 16,
                 ) -> None:
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
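

# --------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): with the
# default 16x16 kernel and stride, a 224x224 image becomes a 14x14 grid
# of patch embeddings. The 224 input resolution is an assumed example
# value.
# --------------------------------------------------------------------
def _demo_patch_embed() -> None:
    patch_embed = PatchEmbed(in_chans=3, embed_dim=768, kernel_size=16, stride=16)
    x = torch.randn(2, 3, 224, 224)              # [B, C, H, W]
    feat = patch_embed(x)                        # [B, 768, 14, 14]
    tokens = feat.flatten(2).transpose(1, 2)     # [B, 196, 768], ready for ViTBlock
    assert tokens.shape == (2, 196, 768)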


# ----------------------- Model modules -----------------------
class ViTBlock(nn.Module):
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float = 4.0,
                 qkv_bias: bool = True,
                 act_layer: Type[nn.Module] = nn.GELU,
                 dropout: float = 0.
                 ) -> None:
        super().__init__()
        # -------------- Model parameters --------------
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim=dim,
                              qkv_bias=qkv_bias,
                              num_heads=num_heads,
                              dropout=dropout,
                              )
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        # Attention (with pre-norm)
        x = self.norm1(x)
        x = self.attn(x)
        x = shortcut + x
        # Feed-forward (with pre-norm)
        x = x + self.ffn(self.norm2(x))
        return x
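

# --------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one
# pre-norm Transformer block applied to a ViT-Base-style token sequence.
# The 768-dim / 12-head sizes are assumed example values.
# --------------------------------------------------------------------
def _demo_vit_block() -> None:
    block = ViTBlock(dim=768, num_heads=12, mlp_ratio=4.0, qkv_bias=True)
    tokens = torch.randn(2, 196, 768)   # [B, N, C], e.g. from PatchEmbed
    out = block(tokens)                 # residual block keeps [B, N, C]
    assert out.shape == tokens.shape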


class Attention(nn.Module):
    def __init__(self,
                 dim: int,
                 qkv_bias: bool = False,
                 num_heads: int = 8,
                 dropout: float = 0.
                 ) -> None:
        super().__init__()
        # --------------- Basic parameters ---------------
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # --------------- Network parameters ---------------
        self.qkv_proj = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bs, N, _ = x.shape
        # ----------------- Input proj -----------------
        qkv = self.qkv_proj(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # ----------------- Multi-head Attn -----------------
        ## [B, N, C] -> [B, N, H, C_h] -> [B, H, N, C_h]
        q = q.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        k = k.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        v = v.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
        ## [B, H, Nq, C_h] x [B, H, C_h, Nk] = [B, H, Nq, Nk]
        attn = (q * self.scale) @ k.transpose(-1, -2)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v  # [B, H, Nq, C_h]
        # ----------------- Output -----------------
        x = x.permute(0, 2, 1, 3).contiguous().view(bs, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
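

# --------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): standalone
# multi-head self-attention over a token sequence. The sizes below are
# assumed example values; dim must be divisible by num_heads.
# --------------------------------------------------------------------
def _demo_attention() -> None:
    attn = Attention(dim=256, num_heads=8, qkv_bias=True, dropout=0.0)
    x = torch.randn(2, 50, 256)   # [B, N, C]
    out = attn(x)                 # self-attention preserves [B, N, C]
    assert out.shape == x.shape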


# ----------------------- Classifier -----------------------
class AttentionPoolingClassifier(nn.Module):
    """
    Adapted from https://github.com/apple/ml-aim/blob/main/aim/torch/layers.py
    """
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_heads: int = 12,
        qkv_bias: bool = False,
        num_queries: int = 1,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = in_dim // num_heads
        self.scale = head_dim ** -0.5

        self.k = nn.Linear(in_dim, in_dim, bias=qkv_bias)
        self.v = nn.Linear(in_dim, in_dim, bias=qkv_bias)

        self.cls_token = nn.Parameter(torch.randn(1, num_queries, in_dim) * 0.02)
        self.linear = nn.Linear(in_dim, out_dim)
        self.bn = nn.BatchNorm1d(in_dim, affine=False, eps=1e-6)

        self.num_queries = num_queries

    def forward(self, x: torch.Tensor):
        B, N, C = x.shape
        # Pre-norm (BatchNorm over the channel dim)
        x = self.bn(x.transpose(-2, -1)).transpose(-2, -1)
        # [1, Q, C] -> [B, Q, C]
        cls_token = self.cls_token.expand(B, -1, -1)
        # [B, Q, C] -> [B, Q, H, C_h] -> [B, H, Q, C_h]
        q = cls_token.reshape(
            B, self.num_queries, self.num_heads, C // self.num_heads
        ).permute(0, 2, 1, 3)
        # [B, N, C] -> [B, N, H, C_h] -> [B, H, N, C_h]
        k = self.k(x).reshape(
            B, N, self.num_heads, C // self.num_heads
        ).permute(0, 2, 1, 3)
        v = self.v(x).reshape(
            B, N, self.num_heads, C // self.num_heads
        ).permute(0, 2, 1, 3)
        # Attention: learned query tokens attend over all N input tokens
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        x_cls = (attn @ v).transpose(1, 2).reshape(B, self.num_queries, C)
        x_cls = x_cls.mean(dim=1)
        # Classify
        out = self.linear(x_cls)
        return out, x_cls
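

# --------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): attention
# pooling over a token sequence for classification. The 768-dim features,
# 1000 classes, and 196 tokens are assumed example values.
# --------------------------------------------------------------------
def _demo_attention_pooling_classifier() -> None:
    pool = AttentionPoolingClassifier(in_dim=768, out_dim=1000, num_heads=12, num_queries=1)
    tokens = torch.randn(2, 196, 768)    # [B, N, C] features from a ViT backbone
    logits, x_cls = pool(tokens)
    assert logits.shape == (2, 1000)     # class logits
    assert x_cls.shape == (2, 768)       # pooled feature


if __name__ == "__main__":
    # Run the shape-check sketches above.
    _demo_feed_forward()
    _demo_patch_embed()
    _demo_vit_block()
    _demo_attention()
    _demo_attention_pooling_classifier()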