
add ViT into iclab/

yjh0410 1 year ago
parent commit 2925b14a3e
4 changed files with 374 additions and 0 deletions
  1. iclab/models/__init__.py (+3, -0)
  2. iclab/models/vit/build.py (+30, -0)
  3. iclab/models/vit/modules.py (+191, -0)
  4. iclab/models/vit/vit.py (+150, -0)

+ 3 - 0
iclab/models/__init__.py

@@ -2,6 +2,7 @@ from .elandarknet.build import build_elandarknet
 from .cspdarknet.build  import build_cspdarknet
 from .darknet.build     import build_darknet
 from .gelan.build       import build_gelan
+from .vit.build         import build_vit
 
 
 def build_model(args):
@@ -14,6 +15,8 @@ def build_model(args):
         model = build_darknet(args)
     elif 'gelan' in args.model:
         model = build_gelan(args)
+    elif 'vit' in args.model:
+        model = build_vit(args)
     else:
         raise NotImplementedError("Unknown model: {}".format(args.model))
 

+ 30 - 0
iclab/models/vit/build.py

@@ -0,0 +1,30 @@
+import torch.nn as nn
+from .vit import ImageEncoderViT, ViTForImageClassification
+
+
+def build_vit(args):
+    if args.model == "vit_t":
+        img_encoder = ImageEncoderViT(img_size=args.img_size,
+                                      patch_size=args.patch_size,
+                                      in_chans=args.img_dim,
+                                      patch_embed_dim=192,
+                                      depth=12,
+                                      num_heads=3,
+                                      mlp_ratio=4.0,
+                                      act_layer=nn.GELU,
+                                      dropout=0.1)
+    elif args.model == "vit_s":
+        img_encoder = ImageEncoderViT(img_size=args.img_size,
+                                      patch_size=args.patch_size,
+                                      in_chans=args.img_dim,
+                                      patch_embed_dim=384,
+                                      depth=12,
+                                      num_heads=6,
+                                      mlp_ratio=4.0,
+                                      act_layer=nn.GELU,
+                                      dropout=0.1)
+    else:
+        raise NotImplementedError("Unknown vit: {}".format(args.model))
+    
+    # Build ViT for classification
+    return ViTForImageClassification(img_encoder, args.num_classes, qkv_bias=True)
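
For reference, a minimal usage sketch of the new builder (the Namespace field names follow build.py and __init__.py; the concrete values and the Namespace itself are only illustrative assumptions):

    import argparse
    import torch
    from iclab.models import build_model

    # Hypothetical configuration; only the field names come from build_model()/build_vit().
    args = argparse.Namespace(model="vit_t", img_size=224, patch_size=16,
                              img_dim=3, num_classes=1000)
    model = build_model(args)
    logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([1, 1000])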

+ 191 - 0
iclab/models/vit/modules.py

@@ -0,0 +1,191 @@
+# --------------------------------------------------------------------
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Type, Tuple
+
+
+# ----------------------- Basic modules -----------------------
+class FeedForward(nn.Module):
+    def __init__(self,
+                 embedding_dim: int,
+                 mlp_dim: int,
+                 act: Type[nn.Module] = nn.GELU,
+                 dropout: float = 0.0,
+                 ) -> None:
+        super().__init__()
+        self.fc1   = nn.Linear(embedding_dim, mlp_dim)
+        self.drop1 = nn.Dropout(dropout)
+        self.fc2   = nn.Linear(mlp_dim, embedding_dim)
+        self.drop2 = nn.Dropout(dropout)
+        self.act   = act()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+        return x
+
+class PatchEmbed(nn.Module):
+    def __init__(self,
+                 in_chans    : int = 3,
+                 embed_dim   : int = 768,
+                 kernel_size : int = 16,
+                 padding     : int = 0,
+                 stride      : int = 16,
+                 ) -> None:
+        super().__init__()
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.proj(x)
+
+
+# ----------------------- Model modules -----------------------
+class ViTBlock(nn.Module):
+    def __init__(self,
+                 dim       :int,
+                 num_heads :int,
+                 mlp_ratio :float = 4.0,
+                 qkv_bias  :bool = True,
+                 act_layer :Type[nn.Module] = nn.GELU,
+                 dropout   :float = 0.
+                 ) -> None:
+        super().__init__()
+        # -------------- Model parameters --------------
+        self.norm1 = nn.LayerNorm(dim)
+        self.attn  = Attention(dim         = dim,
+                               qkv_bias    = qkv_bias,
+                               num_heads   = num_heads,
+                               dropout     = dropout
+                               )
+        self.norm2 = nn.LayerNorm(dim)
+        self.ffn   = FeedForward(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer, dropout=dropout)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shortcut = x
+        # Attention (with prenorm)
+        x = self.norm1(x)
+        x = self.attn(x)
+        x = shortcut + x
+
+        # Feedforward (with prenorm)
+        x = x + self.ffn(self.norm2(x))
+
+        return x
+
+class Attention(nn.Module):
+    def __init__(self,
+                 dim       :int,
+                 qkv_bias  :bool  = False,
+                 num_heads :int   = 8,
+                 dropout   :float = 0.
+                 ):
+        super().__init__()
+        # --------------- Basic parameters ---------------
+        self.dim = dim
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+
+        # --------------- Network parameters ---------------
+        self.qkv_proj = nn.Linear(dim, dim*3, bias = qkv_bias)
+        self.attn_drop = nn.Dropout(dropout)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(dropout)
+
+    def forward(self, x):
+        bs, N, _ = x.shape
+        # ----------------- Input proj -----------------
+        qkv = self.qkv_proj(x)
+        q, k, v = torch.chunk(qkv, 3, dim=-1)
+
+        # ----------------- Multi-head Attn -----------------
+        ## [B, N, C] -> [B, N, H, C_h] -> [B, H, N, C_h]
+        q = q.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+        k = k.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+        v = v.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+        ## [B, H, Nq, C_h] X [B, H, C_h, Nk] = [B, H, Nq, Nk]
+        attn = q * self.scale @ k.transpose(-1, -2)
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = attn @ v # [B, H, Nq, C_h]
+
+        # ----------------- Output -----------------
+        x = x.permute(0, 2, 1, 3).contiguous().view(bs, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
+
+
+# ----------------------- Classifier -----------------------
+class AttentionPoolingClassifier(nn.Module):
+    """
+    Referenced from https://github.com/apple/ml-aim/blob/main/aim/torch/layers.py
+    """
+    def __init__(
+        self,
+        in_dim      : int,
+        out_dim     : int,
+        num_heads   : int = 12,
+        qkv_bias    : bool = False,
+        num_queries : int = 1,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = in_dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.k = nn.Linear(in_dim, in_dim, bias=qkv_bias)
+        self.v = nn.Linear(in_dim, in_dim, bias=qkv_bias)
+
+        self.cls_token = nn.Parameter(torch.randn(1, num_queries, in_dim) * 0.02)
+        self.linear = nn.Linear(in_dim, out_dim)
+        self.bn = nn.BatchNorm1d(in_dim, affine=False, eps=1e-6)
+
+        self.num_queries = num_queries
+
+    def forward(self, x: torch.Tensor):
+        B, N, C = x.shape
+        # Prenorm
+        x = self.bn(x.transpose(-2, -1)).transpose(-2, -1)
+
+        # [1, num_queries, C] -> [B, num_queries, C]
+        cls_token = self.cls_token.expand(B, -1, -1)
+
+        # [B, 1, C] -> [B, 1, H, C_h] -> [B, H, 1, C_h]
+        q = cls_token.reshape(
+            B, self.num_queries, self.num_heads, C // self.num_heads
+        ).permute(0, 2, 1, 3)
+
+        # [B, N, C] -> [B, N, H, C_h] -> [B, H, N, C_h]
+        k = self.k(x).reshape(
+            B, N, self.num_heads, C // self.num_heads
+            ).permute(0, 2, 1, 3)
+        v = self.v(x).reshape(
+            B, N, self.num_heads, C // self.num_heads
+            ).permute(0, 2, 1, 3)
+
+        # Attention
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+        attn = attn.softmax(dim=-1)
+
+        x_cls = (attn @ v).transpose(1, 2).reshape(B, self.num_queries, C)
+        x_cls = x_cls.mean(dim=1)
+
+        # Classify
+        out = self.linear(x_cls)
+
+        return out, x_cls
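
A quick shape sanity check for the new modules (a minimal sketch; it assumes the package is importable from the repository root):

    import torch
    from iclab.models.vit.modules import ViTBlock, AttentionPoolingClassifier

    x = torch.randn(2, 196, 192)           # [B, N, C]: 14x14 patches of a ViT-Tiny
    block = ViTBlock(dim=192, num_heads=3)
    print(block(x).shape)                  # torch.Size([2, 196, 192])

    pool = AttentionPoolingClassifier(in_dim=192, out_dim=10, num_heads=3)
    logits, cls_feat = pool(x)
    print(logits.shape, cls_feat.shape)    # torch.Size([2, 10]) torch.Size([2, 192])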

+ 150 - 0
iclab/models/vit/vit.py

@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .modules import PatchEmbed, ViTBlock, AttentionPoolingClassifier
+except ImportError:
+    from modules import PatchEmbed, ViTBlock, AttentionPoolingClassifier
+
+
+# ---------- Vision transformer ----------
+class ImageEncoderViT(nn.Module):
+    def __init__(self,
+                 img_size: int,
+                 patch_size: int,
+                 in_chans: int,
+                 patch_embed_dim: int,
+                 depth: int,
+                 num_heads: int,
+                 mlp_ratio: float,
+                 act_layer: type = nn.GELU,
+                 dropout: float = 0.0,
+                 ) -> None:
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.patch_embed_dim = patch_embed_dim
+        self.num_heads = num_heads
+        self.num_patches = (img_size // patch_size) ** 2
+        # ----------- Model parameters -----------
+        self.patch_embed = PatchEmbed(in_chans, patch_embed_dim, patch_size, stride=patch_size)
+        self.pos_embed   = nn.Parameter(torch.zeros(1, self.num_patches, patch_embed_dim))
+        self.norm_layer  = nn.LayerNorm(patch_embed_dim)
+        self.blocks      = nn.ModuleList([
+            ViTBlock(patch_embed_dim, num_heads, mlp_ratio, True, act_layer, dropout)
+            for _ in range(depth)])
+
+        self._init_weights()
+
+    def _init_weights(self):
+        # initialize pos_embed with a sin-cos embedding (the parameter stays trainable)
+        pos_embed = self.get_posembed(self.pos_embed.shape[-1], int(self.num_patches**.5))
+        self.pos_embed.data.copy_(pos_embed)
+
+        # initialize nn.Linear and nn.LayerNorm
+        for m in self.modules():           
+            if isinstance(m, nn.Linear):
+                # we use xavier_uniform following official JAX ViT:
+                torch.nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+    def get_posembed(self, embed_dim, grid_size, temperature=10000):
+        scale = 2 * torch.pi
+        grid_h, grid_w = grid_size, grid_size
+        num_pos_feats = embed_dim // 2
+        # get grid
+        y_embed, x_embed = torch.meshgrid(torch.arange(grid_h, dtype=torch.float32),
+                                          torch.arange(grid_w, dtype=torch.float32), indexing='ij')
+        # normalize grid coords
+        y_embed = y_embed / (grid_h + 1e-6) * scale
+        x_embed = x_embed / (grid_w + 1e-6) * scale
+    
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
+        dim_t = temperature ** (2 * dim_t_)
+
+        pos_x = torch.div(x_embed[..., None], dim_t)
+        pos_y = torch.div(y_embed[..., None], dim_t)
+        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
+        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
+
+        # [H, W, C] -> [N, C]
+        pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)
+
+        return pos_embed.unsqueeze(0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Patch embed
+        x = self.patch_embed(x)
+        x = x.flatten(2).permute(0, 2, 1).contiguous()
+
+        # Add pos embed
+        x = x + self.pos_embed
+
+        # Apply Transformer blocks
+        for block in self.blocks:
+            x = block(x)
+        x = self.norm_layer(x)
+
+        return x
+
+
+# ---------- Vision transformer for classification ----------
+class ViTForImageClassification(nn.Module):
+    def __init__(self,
+                 image_encoder :ImageEncoderViT,
+                 num_classes   :int   = 1000,
+                 qkv_bias      :bool  = True,
+                 ):
+        super().__init__()
+        # -------- Model parameters --------
+        self.encoder    = image_encoder
+        self.classifier = AttentionPoolingClassifier(image_encoder.patch_embed_dim,
+                                                     num_classes,
+                                                     image_encoder.num_heads,
+                                                     qkv_bias,
+                                                     num_queries=1)
+
+    def forward(self, x):
+        """
+        Inputs:
+            x: (torch.Tensor) -> [B, C, H, W]. Input image.
+        """
+        x = self.encoder(x)
+        x, x_cls = self.classifier(x)
+
+        return x
+
+
+
+if __name__=='__main__':
+    import time
+
+    # Build the ViT model
+    img_encoder = ImageEncoderViT(img_size=224,
+                                  patch_size=16,
+                                  in_chans=3,
+                                  patch_embed_dim=192,
+                                  depth=12,
+                                  num_heads=3,
+                                  mlp_ratio=4.0,
+                                  act_layer=nn.GELU,
+                                  dropout=0.1)
+    model = ViTForImageClassification(img_encoder, num_classes=10, qkv_bias=True)
+
+    # Print the model structure
+    print(model)
+
+    # Randomly generate input data
+    x = torch.randn(1, 3, 224, 224)
+
+    # Model forward inference
+    t0 = time.time()
+    output = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
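
The sine-cosine position embedding produced by get_posembed can also be inspected on its own, continuing the tiny-encoder configuration of the __main__ block above (a small follow-up sketch):

    # num_patches = (224 // 16) ** 2 = 196, so the result matches self.pos_embed.
    pos = img_encoder.get_posembed(embed_dim=192, grid_size=224 // 16)
    print(pos.shape)  # torch.Size([1, 196, 192])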