# vit_mae.py

import math

import torch
import torch.nn as nn

try:
    from .modules import ViTBlock, PatchEmbed
except ImportError:
    from modules import ViTBlock, PatchEmbed


# ------------------------ Basic Modules ------------------------
class MaeEncoder(nn.Module):
    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 patch_embed_dim: int,
                 depth: int,
                 num_heads: int,
                 mlp_ratio: float,
                 act_layer: nn.GELU,
                 mask_ratio: float = 0.75,
                 dropout: float = 0.0,
                 ) -> None:
        super().__init__()
        # ----------- Basic parameters -----------
        self.img_size = img_size
        self.patch_size = patch_size
        self.image_embedding_size = img_size // (patch_size if patch_size > 0 else 1)
        self.patch_embed_dim = patch_embed_dim
        self.num_heads = num_heads
        self.num_patches = (img_size // patch_size) ** 2
        self.mask_ratio = mask_ratio
        # ----------- Model parameters -----------
        self.patch_embed = PatchEmbed(in_chans, patch_embed_dim, patch_size, 0, patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, patch_embed_dim), requires_grad=False)
        self.norm_layer = nn.LayerNorm(patch_embed_dim)
        self.blocks = nn.ModuleList([
            ViTBlock(patch_embed_dim, num_heads, mlp_ratio, True, act_layer=act_layer, dropout=dropout)
            for _ in range(depth)])

        self._init_weights()
    def _init_weights(self):
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = self.get_posembed(self.pos_embed.shape[-1], int(self.num_patches**.5))
        self.pos_embed.data.copy_(pos_embed)

        # initialize nn.Linear and nn.LayerNorm
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # we use xavier_uniform following official JAX ViT:
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    def get_posembed(self, embed_dim, grid_size, temperature=10000):
        scale = 2 * math.pi
        grid_h, grid_w = grid_size, grid_size
        num_pos_feats = embed_dim // 2

        # get grid
        y_embed, x_embed = torch.meshgrid([torch.arange(grid_h, dtype=torch.float32),
                                           torch.arange(grid_w, dtype=torch.float32)])
        # normalize grid coords
        y_embed = y_embed / (grid_h + 1e-6) * scale
        x_embed = x_embed / (grid_w + 1e-6) * scale

        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
        dim_t = temperature ** (2 * dim_t_)

        pos_x = torch.div(x_embed[..., None], dim_t)
        pos_y = torch.div(y_embed[..., None], dim_t)
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)

        # [H, W, C] -> [N, C]
        pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)

        return pos_embed.unsqueeze(0)
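
    # Shape note: with img_size=224 and patch_size=16 (used here only as an example),
    # grid_size is 14 and the returned pos_embed is [1, 196, embed_dim]; the first half
    # of the channels encodes the y coordinate and the second half the x coordinate,
    # each as interleaved sin/cos terms at geometrically spaced frequencies.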
    def random_masking(self, x):
        B, N, C = x.shape
        len_keep = int(N * (1 - self.mask_ratio))

        noise = torch.rand(B, N, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)        # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)  # restore the original position of each patch

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, C))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([B, N], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
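
    # Worked example (hypothetical values): with N=4 patches and mask_ratio=0.5,
    # len_keep=2. If ids_shuffle=[2, 0, 3, 1] for one sample, then ids_keep=[2, 0],
    # x_masked holds patches 2 and 0, ids_restore=[1, 3, 0, 2], and the unshuffled
    # mask is [0, 1, 0, 1] (patches 1 and 3 are masked out).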
    def forward(self, x: torch.Tensor):
        # patch embed
        x = self.patch_embed(x)
        # [B, C, H, W] -> [B, C, N] -> [B, N, C], N = H x W
        x = x.flatten(2).permute(0, 2, 1).contiguous()

        # add pos embed
        x = x + self.pos_embed

        # masking: length -> length * (1 - mask_ratio)
        x, mask, ids_restore = self.random_masking(x)

        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm_layer(x)

        return x, mask, ids_restore
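
# Shape sketch for MaeEncoder (e.g. img_size=224, patch_size=16, mask_ratio=0.75):
#   input x:                [B, 3, 224, 224]
#   patch_embed + flatten:  [B, 196, patch_embed_dim]
#   after random_masking:   [B, 49, patch_embed_dim]   (49 = int(196 * 0.25) visible patches)
#   returns (x_visible, mask [B, 196], ids_restore [B, 196])
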
class MaeDecoder(nn.Module):
    def __init__(self,
                 img_dim: int = 3,
                 img_size: int = 16,
                 patch_size: int = 16,
                 en_emb_dim: int = 784,
                 de_emb_dim: int = 512,
                 de_num_layers: int = 12,
                 de_num_heads: int = 12,
                 qkv_bias: bool = True,
                 mlp_ratio: float = 4.0,
                 dropout: float = 0.1,
                 mask_ratio: float = 0.75,
                 ):
        super().__init__()
        # -------- basic parameters --------
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.en_emb_dim = en_emb_dim
        self.de_emb_dim = de_emb_dim
        self.de_num_layers = de_num_layers
        self.de_num_heads = de_num_heads
        self.mask_ratio = mask_ratio
        # -------- network parameters --------
        self.decoder_embed = nn.Linear(en_emb_dim, de_emb_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, de_emb_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, de_emb_dim), requires_grad=False)  # fixed sin-cos embedding
        self.decoder_norm = nn.LayerNorm(de_emb_dim)
        self.decoder_pred = nn.Linear(de_emb_dim, patch_size**2 * img_dim, bias=True)
        self.blocks = nn.ModuleList([
            ViTBlock(de_emb_dim, de_num_heads, mlp_ratio, qkv_bias, dropout=dropout)
            for _ in range(de_num_layers)])

        self._init_weights()
    def _init_weights(self):
        # initialize (and freeze) pos_embed by sin-cos embedding
        decoder_pos_embed = self.get_posembed(self.decoder_pos_embed.shape[-1], int(self.num_patches**.5))
        self.decoder_pos_embed.data.copy_(decoder_pos_embed)

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # we use xavier_uniform following official JAX ViT:
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    def get_posembed(self, embed_dim, grid_size, temperature=10000):
        scale = 2 * math.pi
        grid_h, grid_w = grid_size, grid_size
        num_pos_feats = embed_dim // 2

        # get grid
        y_embed, x_embed = torch.meshgrid([torch.arange(grid_h, dtype=torch.float32),
                                           torch.arange(grid_w, dtype=torch.float32)])
        # normalize grid coords
        y_embed = y_embed / (grid_h + 1e-6) * scale
        x_embed = x_embed / (grid_w + 1e-6) * scale

        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
        dim_t = temperature ** (2 * dim_t_)

        pos_x = torch.div(x_embed[..., None], dim_t)
        pos_y = torch.div(y_embed[..., None], dim_t)
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)

        # [H, W, C] -> [N, C]
        pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)

        return pos_embed.unsqueeze(0)
    def forward(self, x_enc, ids_restore):
        # embed tokens
        x_enc = self.decoder_embed(x_enc)
        B, N_nomask, C = x_enc.shape

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(B, ids_restore.shape[1] - N_nomask, 1)  # [B, N_mask, C], N_mask = N - N_nomask
        x_all = torch.cat([x_enc, mask_tokens], dim=1)
        x_all = torch.gather(x_all, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, C))  # unshuffle

        # add pos embed
        x_all = x_all + self.decoder_pos_embed

        # apply Transformer blocks
        for block in self.blocks:
            x_all = block(x_all)
        x_all = self.decoder_norm(x_all)

        # predict the pixel values of each patch
        x_out = self.decoder_pred(x_all)

        return x_out
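
# Shape sketch for MaeDecoder (same example settings as above, de_emb_dim=512):
#   x_enc [B, 49, en_emb_dim] -> decoder_embed          -> [B, 49, 512]
#   + mask tokens, unshuffled via ids_restore           -> [B, 196, 512]
#   decoder_pred -> [B, 196, patch_size**2 * img_dim]    = [B, 196, 768]
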
# ------------------------ MAE Vision Transformer ------------------------
class ViTforMaskedAutoEncoder(nn.Module):
    def __init__(self,
                 encoder: MaeEncoder,
                 decoder: MaeDecoder,
                 ):
        super().__init__()
        self.mae_encoder = encoder
        self.mae_decoder = decoder

    def patchify(self, imgs, patch_size):
        """
        imgs: (B, 3, H, W)
        x:    (B, N, patch_size**2 * 3)
        """
        p = patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))

        return x
    def unpatchify(self, x, patch_size):
        """
        x:    (B, N, patch_size**2 * 3)
        imgs: (B, 3, H, W)
        """
        p = patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))

        return imgs
    def compute_loss(self, x, output):
        """
        x:      [B, 3, H, W], input images
        output: dict with
            "x_pred": [B, N, C], C = p*p*3
            "mask":   [B, N], 0 is keep, 1 is remove
        """
        target = self.patchify(x, self.mae_encoder.patch_size)
        pred, mask = output["x_pred"], output["mask"]

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)                 # [B, N], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches

        return loss
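
    # Loss note (example numbers): with 196 patches and mask_ratio=0.75, mask.sum()
    # is 147 per sample, so the per-patch MSE is averaged only over the 147 masked
    # patches; visible patches do not contribute to the loss.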
    def forward(self, x):
        imgs = x
        x, mask, ids_restore = self.mae_encoder(x)
        x = self.mae_decoder(x, ids_restore)

        output = {
            'x_pred': x,
            'mask': mask
        }
        if self.training:
            loss = self.compute_loss(imgs, output)
            output["loss"] = loss

        return output
# ------------------------ Model Functions ------------------------
def build_vit_mae(model_name="vit_t", img_size=224, patch_size=16, img_dim=3, mask_ratio=0.75):
    # ---------------- MAE Encoder ----------------
    if model_name == "vit_t":
        encoder = MaeEncoder(img_size=img_size,
                             patch_size=patch_size,
                             in_chans=img_dim,
                             patch_embed_dim=192,
                             depth=12,
                             num_heads=3,
                             mlp_ratio=4.0,
                             act_layer=nn.GELU,
                             mask_ratio=mask_ratio,
                             dropout=0.1)
    elif model_name == "vit_s":
        encoder = MaeEncoder(img_size=img_size,
                             patch_size=patch_size,
                             in_chans=img_dim,
                             patch_embed_dim=384,
                             depth=12,
                             num_heads=6,
                             mlp_ratio=4.0,
                             act_layer=nn.GELU,
                             mask_ratio=mask_ratio,
                             dropout=0.1)
    elif model_name == "vit_b":
        encoder = MaeEncoder(img_size=img_size,
                             patch_size=patch_size,
                             in_chans=img_dim,
                             patch_embed_dim=768,
                             depth=12,
                             num_heads=12,
                             mlp_ratio=4.0,
                             act_layer=nn.GELU,
                             mask_ratio=mask_ratio,
                             dropout=0.1)
    elif model_name == "vit_l":
        encoder = MaeEncoder(img_size=img_size,
                             patch_size=patch_size,
                             in_chans=img_dim,
                             patch_embed_dim=1024,
                             depth=24,
                             num_heads=16,
                             mlp_ratio=4.0,
                             act_layer=nn.GELU,
                             mask_ratio=mask_ratio,
                             dropout=0.1)
    elif model_name == "vit_h":
        encoder = MaeEncoder(img_size=img_size,
                             patch_size=patch_size,
                             in_chans=img_dim,
                             patch_embed_dim=1280,
                             depth=32,
                             num_heads=16,
                             mlp_ratio=4.0,
                             act_layer=nn.GELU,
                             mask_ratio=mask_ratio,
                             dropout=0.1)
    else:
        raise ValueError("Unknown model_name: {}".format(model_name))

    # ---------------- MAE Decoder ----------------
    decoder = MaeDecoder(img_dim=img_dim,
                         img_size=img_size,
                         patch_size=patch_size,
                         en_emb_dim=encoder.patch_embed_dim,
                         de_emb_dim=512,
                         de_num_layers=8,
                         de_num_heads=16,
                         qkv_bias=True,
                         mlp_ratio=4.0,
                         mask_ratio=mask_ratio,
                         dropout=0.1)

    return ViTforMaskedAutoEncoder(encoder, decoder)
if __name__ == '__main__':
    from thop import profile

    # Prepare an image as the input
    bs, c, h, w = 2, 3, 224, 224
    x = torch.randn(bs, c, h, w)
    patch_size = 16

    # Build model
    model = build_vit_mae(patch_size=patch_size)

    # Inference
    outputs = model(x)
    if "loss" in outputs:
        print("Loss: ", outputs["loss"].item())
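
    # Minimal reconstruction sketch (not part of the original demo): turn the per-patch
    # predictions back into images with unpatchify(), and expand the binary mask to
    # pixel space the same way. Assumes the default img_size=224 and img_dim=3.
    pred_imgs = model.unpatchify(outputs["x_pred"], patch_size)  # [B, 3, 224, 224]
    mask_imgs = model.unpatchify(
        outputs["mask"].unsqueeze(-1).repeat(1, 1, patch_size**2 * 3),  # [B, N, p*p*3]
        patch_size)  # 1 = masked pixel, 0 = visible pixel
    # A typical MAE visualization pastes predictions only into the masked regions:
    recon = x * (1 - mask_imgs) + pred_imgs * mask_imgs
    print("Recon shape: ", recon.shape)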
    # Compute FLOPs & Params (thop reports MACs, so multiply by 2 to get FLOPs)
    print('==============================')
    model.eval()
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))