
add dl tutorial

yjh0410 committed 1 year ago
commit 2a3974d075
84 changed files with 3104 additions and 4562 deletions
  1. iclab/README.md (+0 -61)
  2. iclab/data/__init__.py (+0 -36)
  3. iclab/data/imagenet.py (+0 -109)
  4. iclab/models/__init__.py (+0 -23)
  5. iclab/models/cspdarknet/build.py (+0 -18)
  6. iclab/models/cspdarknet/cspdarknet.py (+0 -170)
  7. iclab/models/cspdarknet/modules.py (+0 -136)
  8. iclab/models/darknet/build.py (+0 -18)
  9. iclab/models/darknet/darknet.py (+0 -170)
  10. iclab/models/darknet/modules.py (+0 -130)
  11. iclab/models/elandarknet/build.py (+0 -16)
  12. iclab/models/elandarknet/elandarknet.py (+0 -171)
  13. iclab/models/elandarknet/modules.py (+0 -135)
  14. iclab/models/gelan/build.py (+0 -12)
  15. iclab/models/gelan/gelan.py (+0 -233)
  16. iclab/models/gelan/modules.py (+0 -312)
  17. iclab/train.py (+0 -332)
  18. iclab/train.sh (+0 -164)
  19. iclab/utils/com_flops_params.py (+0 -18)
  20. iclab/utils/distributed_utils.py (+0 -165)
  21. iclab/utils/ema.py (+0 -52)
  22. iclab/utils/optimzer.py (+0 -24)
  23. image_classification/.gitignore (+2 -1)
  24. image_classification/README.md (+27 -0)
  25. image_classification/data/__init__.py (+42 -0)
  26. image_classification/data/cifar.py (+2 -3)
  27. image_classification/data/custom.py (+17 -39)
  28. image_classification/data/mnist.py (+2 -2)
  29. image_classification/engine.py (+24 -48)
  30. image_classification/main.py (+204 -0)
  31. image_classification/models/__init__.py (+20 -0)
  32. image_classification/models/convnet/build.py (+16 -0)
  33. image_classification/models/convnet/convnet.py (+120 -0)
  34. image_classification/models/convnet/modules.py (+86 -0)
  35. image_classification/models/mlp/build.py (+14 -0)
  36. image_classification/models/mlp/mlp.py (+56 -0)
  37. image_classification/models/mlp/modules.py (+48 -0)
  38. image_classification/models/resnet/build.py (+21 -0)
  39. image_classification/models/resnet/modules.py (+164 -0)
  40. image_classification/models/resnet/resnet.py (+110 -0)
  41. image_classification/models/vit/build.py (+0 -0)
  42. image_classification/models/vit/modules.py (+0 -0)
  43. image_classification/models/vit/vit.py (+0 -0)
  44. image_classification/requirements.txt (+0 -0)
  45. image_classification/utils/__init__.py (+0 -0)
  46. image_classification/utils/lr_scheduler.py (+38 -0)
  47. image_classification/utils/misc.py (+198 -0)
  48. image_classification/utils/optimzer.py (+34 -0)
  49. masked_image_modeling/.gitignore (+12 -0)
  50. masked_image_modeling/README.md (+73 -0)
  51. masked_image_modeling/data/__init__.py (+38 -0)
  52. masked_image_modeling/data/cifar.py (+65 -0)
  53. masked_image_modeling/data/custom.py (+87 -0)
  54. masked_image_modeling/engine_finetune.py (+107 -0)
  55. masked_image_modeling/engine_pretrain.py (+64 -0)
  56. masked_image_modeling/main_finetune.py (+175 -0)
  57. masked_image_modeling/main_pretrain.py (+211 -0)
  58. masked_image_modeling/models/__init__.py (+9 -0)
  59. masked_image_modeling/models/vit/__init__.py (+3 -0)
  60. masked_image_modeling/models/vit/build.py (+45 -0)
  61. masked_image_modeling/models/vit/modules.py (+186 -0)
  62. masked_image_modeling/models/vit/pos_embed.py (+96 -0)
  63. masked_image_modeling/models/vit/vit.py (+180 -0)
  64. masked_image_modeling/models/vit/vit_cls.py (+28 -0)
  65. masked_image_modeling/models/vit/vit_mae.py (+399 -0)
  66. masked_image_modeling/requirements.txt (+5 -0)
  67. masked_image_modeling/utils/__init__.py (+0 -0)
  68. masked_image_modeling/utils/lr_scheduler.py (+1 -3)
  69. masked_image_modeling/utils/misc.py (+50 -110)
  70. masked_image_modeling/utils/optimizer.py (+25 -0)
  71. yolo/config/__init__.py (+0 -3)
  72. yolo/config/yolov8_e2e_config.py (+0 -196)
  73. yolo/models/__init__.py (+0 -4)
  74. yolo/models/yolov8_e2e/README.md (+0 -60)
  75. yolo/models/yolov8_e2e/build.py (+0 -24)
  76. yolo/models/yolov8_e2e/loss.py (+0 -204)
  77. yolo/models/yolov8_e2e/matcher.py (+0 -202)
  78. yolo/models/yolov8_e2e/yolov8_backbone.py (+0 -181)
  79. yolo/models/yolov8_e2e/yolov8_basic.py (+0 -172)
  80. yolo/models/yolov8_e2e/yolov8_e2e.py (+0 -179)
  81. yolo/models/yolov8_e2e/yolov8_head.py (+0 -179)
  82. yolo/models/yolov8_e2e/yolov8_neck.py (+0 -85)
  83. yolo/models/yolov8_e2e/yolov8_pafpn.py (+0 -152)
  84. yolo/models/yolov8_e2e/yolov8_pred.py (+0 -210)

+ 0 - 61
iclab/README.md

@@ -1,61 +0,0 @@
-# General Image Classification Laboratory
-
-
-## Train a CNN
-We provide a bash script `train.sh` to train the models. You can modify the hyperparameters in the script to suit your needs.
-
-For example, to train the `ELANDarkNet-S` designed in this repo with 8 GPUs, run the following command:
-
-```Shell
-bash train.sh elandarknet_s imagenet_1k path/to/imagenet_1k 8 1699 None
-```
-
-## Evaluate a CNN
-- Evaluate the `top1 & top5` accuracy of `ELANDarkNet-S` on the ImageNet-1K dataset:
-```Shell
-python main.py --cuda -dataset imagenet_1k --root path/to/imagenet_1k -m elandarknet_s --batch_size 256 --img_size 224 --eval --resume path/to/elandarknet_s.pth
-```
-
-
-## Experimental results
-Tips:
-- **Weak augmentation:** includes `random hflip` and `random crop resize`.
-- **Strong augmentation:** includes `mixup`, `cutmix`, `rand aug`, `random erase`, and so on. However, we do not use strong augmentation here.
-- `AdamW` with `weight decay = 0.05` and `base lr = 4e-3` (for a batch size of 4096) is used as the optimizer, and `CosineAnnealingLR` is used as the lr scheduler with the `min lr` set to 1e-6 (see the sketch at the end of this diff).
-
-### ImageNet-1K
-* Modified DarkNet (YOLOv3's DarkNet with width and depth scaling factors)
-
-|    Model      | Augment | Batch | Epoch | size | acc@1 | GFLOPs | Params |  Weight |
-|---------------|---------|-------|-------|------|-------|--------|--------|---------|
-| DarkNet-S     |   weak  |  4096 |  100  | 224  |  68.5 |  1.6   |  4.6 M | [ckpt](https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/darknet_s_in1k_68.5.pth) |
-| DarkNet-M     |   weak  |  4096 |  100  | 224  |       |        |        |  |
-| DarkNet-L     |   weak  |  4096 |  100  | 224  |       |        |        |  |
-| DarkNet-X     |   weak  |  4096 |  100  | 224  |       |        |        |  |
-
-* Modified CSPDarkNet (YOLOv5's DarkNet with width and depth scaling factors)
-
-|    Model      | Augment | Batch | Epoch | size | acc@1 | GFLOPs | Params |  Weight |
-|---------------|---------|-------|-------|------|-------|--------|--------|---------|
-| CSPDarkNet-S  |   weak  |  4096 |  100  | 224  |  70.2 |  1.3   | 4.0 M  | [ckpt](https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/cspdarknet_s_in1k_70.2.pth) |
-| CSPDarkNet-M  |   weak  |  4096 |  100  | 224  |       |        |        |  |
-| CSPDarkNet-L  |   weak  |  4096 |  100  | 224  |       |        |        |  |
-| CSPDarkNet-X  |   weak  |  4096 |  100  | 224  |       |        |        |  |
-
-* ELANDarkNet (YOLOv8's backbone)
-
-|         Model          | Augment | Batch | Epoch | size | acc@1 | GFLOPs | Params  |  Weight |
-|------------------------|---------|-------|-------|------|-------|--------|---------|---------|
-| ELANDarkNet-N      |   weak  |  4096 |  100  | 224  |  62.1 |  0.38  | 1.36 M  | [ckpt](https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elandarknet_n_in1k_62.1.pth) |
-| ELANDarkNet-S      |   weak  |  4096 |  100  | 224  |  71.3 |  1.48  | 4.94 M  | [ckpt](https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elandarknet_s_in1k_71.3.pth) |
-| ELANDarkNet-M      |   weak  |  4096 |  100  | 224  |       |  4.67  | 11.60 M |  |
-| ELANDarkNet-L      |   weak  |  4096 |  100  | 224  |       |  10.47 | 19.66 M |  |
-| ELANDarkNet-X      |   weak  |  4096 |  100  | 224  |       |  20.56 | 37.86 M |  |
-
-
-* GELAN (proposed in YOLOv9)
-
-|     Model     | Augment | Batch | Epoch | size | acc@1 | GFLOPs | Params  |  Weight |
-|---------------|---------|-------|-------|------|-------|--------|---------|---------|
-| GELAN-S       |   weak  |  4096 |  100  | 224  | 68.4  |  0.9   | 1.9 M   | [ckpt](https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/gelan_s_in1k_68.4.pth) |
-| GELAN-C       |   weak  |  4096 |  100  | 224  |   |  5.2   | 8.8 M   | [ckpt]()|
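Below is a minimal sketch (plain PyTorch) of the optimizer and lr-scheduler setup described in the tips of the deleted README above. The model, batch size, and epoch count are placeholders for illustration only; the actual `iclab/train.py` and `iclab/utils/optimzer.py` may differ in detail.

```python
import torch
import torch.nn as nn

# Placeholders for illustration; the real training script builds these from CLI args.
model = nn.Linear(3 * 224 * 224, 1000)
num_epochs = 100
batch_size = 256

# AdamW with weight decay 0.05; the base lr of 4e-3 is defined for a batch size of 4096,
# so it is scaled linearly with the actual batch size.
lr = 4e-3 * batch_size / 4096
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)

# Cosine annealing down to a minimum lr of 1e-6 over the training epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

for epoch in range(num_epochs):
    # ... forward / backward passes over the dataloader would go here ...
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
```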

+ 0 - 36
iclab/data/__init__.py

@@ -1,36 +0,0 @@
-import torch.utils.data as data
-
-from .cifar import CifarDataset
-from .mnist import MnistDataset
-from .imagenet import ImageNet1KDataset
-from .custom import CustomDataset
-
-
-def build_dataset(args, transform=None, is_train=False):
-    if args.dataset == 'cifar10':
-        args.num_classes = 10
-        args.img_dim = 3
-        return CifarDataset(is_train, transform)
-    elif args.dataset == 'mnist':
-        args.num_classes = 10
-        args.img_dim = 1
-        return MnistDataset(is_train, transform)
-    elif args.dataset == 'imagenet_1k':
-        args.num_classes = 1000
-        args.img_dim = 3
-        return ImageNet1KDataset(args, is_train, transform)
-    elif args.dataset == 'custom':
-        assert args.num_classes is not None and isinstance(args.num_classes, int)
-        args.img_dim = 3
-        return CustomDataset(args, is_train, transform)
-    
-
-def build_dataloader(args, dataset, is_train=False):
-    if is_train:
-        sampler = data.distributed.DistributedSampler(dataset) if args.distributed else data.RandomSampler(dataset)
-        batch_sampler_train = data.BatchSampler(sampler, args.batch_size // args.world_size, drop_last=True if is_train else False)
-        dataloader = data.DataLoader(dataset, batch_sampler=batch_sampler_train, num_workers=args.num_workers, pin_memory=True)
-    else:
-        dataloader = data.DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
-
-    return dataloader
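As a usage sketch for the two helpers in the deleted `iclab/data/__init__.py` above: the `args` fields below are the ones `build_dataset()` and `build_dataloader()` actually read, assuming the script runs from the `iclab/` directory so the imports resolve. A real script would normally also pass a transform, since whether a dataset supplies a default one when `transform=None` depends on the dataset class.

```python
from argparse import Namespace

from data import build_dataset, build_dataloader  # assumes the iclab/ sources are on the path

args = Namespace(
    dataset='cifar10',    # 'cifar10' | 'mnist' | 'imagenet_1k' | 'custom'
    num_classes=None,     # filled in by build_dataset() for the built-in datasets
    batch_size=256,
    num_workers=4,
    distributed=False,
    world_size=1,
)

train_dataset = build_dataset(args, transform=None, is_train=True)
train_loader  = build_dataloader(args, train_dataset, is_train=True)

print(len(train_dataset), args.num_classes, args.img_dim)
```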

+ 0 - 109
iclab/data/imagenet.py

@@ -1,109 +0,0 @@
-import os
-import PIL
-import numpy as np
-from timm.data import create_transform
-import torch.utils.data as data
-import torchvision.transforms as T
-from torchvision.datasets import ImageFolder
-import torchvision.transforms as transforms
-
-
-class ImageNet1KDataset(data.Dataset):
-    def __init__(self, args, is_train=False, transform=None):
-        super().__init__()
-        # ----------------- basic parameters -----------------
-        self.args = args
-        self.is_train   = is_train
-        self.pixel_mean = [0.485, 0.456, 0.406]
-        self.pixel_std  = [0.229, 0.224, 0.225]
-        print("Pixel mean: {}".format(self.pixel_mean))
-        print("Pixel std:  {}".format(self.pixel_std))
-        self.image_set = 'train' if is_train else 'val'
-        self.data_path = os.path.join(args.root, self.image_set)
-        # ----------------- dataset & transforms -----------------
-        self.transform = transform if transform is not None else self.build_transform(args)
-        self.dataset = ImageFolder(root=self.data_path, transform=self.transform)
-
-    def __len__(self):
-        return len(self.dataset)
-    
-    def __getitem__(self, index):
-        image, target = self.dataset[index]
-
-        return image, target
-    
-    def pull_image(self, index):
-        # load data
-        image, target = self.dataset[index]
-
-        # denormalize image
-        image = image.permute(1, 2, 0).numpy()
-        image = (image * self.pixel_std + self.pixel_mean) * 255.
-        image = image.astype(np.uint8)
-        image = image.copy()
-
-        return image, target
-
-    def build_transform(self, args):
-        if self.is_train:
-            transforms = create_transform(input_size    = args.img_size,
-                                          is_training   = True,
-                                          color_jitter  = args.color_jitter,
-                                          auto_augment  = args.aa,
-                                          interpolation = 'bicubic',
-                                          re_prob       = args.reprob,
-                                          re_mode       = args.remode,
-                                          re_count      = args.recount,
-                                          mean          = self.pixel_mean,
-                                          std           = self.pixel_std,
-                                          )
-        else:
-            t = []
-            if args.img_size <= 224:
-                crop_pct = 224 / 256
-            else:
-                crop_pct = 1.0
-            size = int(args.img_size / crop_pct)
-            t.append(
-                T.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
-            )
-            t.append(T.CenterCrop(args.img_size))
-            t.append(T.ToTensor())
-            t.append(T.Normalize(self.pixel_mean, self.pixel_std))
-            transforms = T.Compose(t)
-
-        return transforms
-
-
-if __name__ == "__main__":
-    import cv2
-    import torch
-    import argparse
-    
-    parser = argparse.ArgumentParser(description='ImageNet-Dataset')
-
-    # opt
-    parser.add_argument('--root', default='/mnt/share/ssd2/dataset/imagenet/',
-                        help='data root')
-    parser.add_argument('--img_size', default=224, type=int,
-                        help='input image size.')
-    args = parser.parse_args()
-
-    # Transforms
-    train_transform = transforms.Compose([
-            transforms.RandomResizedCrop(args.img_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
-  
-    # Dataset
-    dataset = ImageNet1KDataset(args, is_train=True)  
-    print('Dataset size: ', len(dataset))
-
-    for i in range(1000):
-        image, target = dataset.pull_image(i)
-        # to BGR
-        image = image[..., (2, 1, 0)]
-
-        cv2.imshow('image', image)
-        cv2.waitKey(0)

+ 0 - 23
iclab/models/__init__.py

@@ -1,23 +0,0 @@
-from .elandarknet.build import build_elandarknet
-from .cspdarknet.build  import build_cspdarknet
-from .darknet.build     import build_darknet
-from .gelan.build       import build_gelan
-from .vit.build         import build_vit
-
-
-def build_model(args):
-    # --------------------------- Backbone series ---------------------------
-    if   'elandarknet' in args.model:
-        model = build_elandarknet(args)
-    elif 'cspdarknet' in args.model:
-        model = build_cspdarknet(args)
-    elif 'darknet' in args.model:
-        model = build_darknet(args)
-    elif 'gelan' in args.model:
-        model = build_gelan(args)
-    elif 'vit' in args.model:
-        model = build_vit(args)
-    else:
-        raise NotImplementedError("Unknown model: {}".format(args.model))
-
-    return model
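A quick usage sketch for the dispatcher above: it keys on substrings of `args.model` and forwards `args.img_dim` and `args.num_classes` to the matching builder (again assuming the `iclab/` sources are importable).

```python
import torch
from argparse import Namespace

from models import build_model  # assumes running from the iclab/ directory

args = Namespace(model='cspdarknet_s', img_dim=3, num_classes=1000)
model = build_model(args)

logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```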

+ 0 - 18
iclab/models/cspdarknet/build.py

@@ -1,18 +0,0 @@
-from .cspdarknet import cspdarknet_n, cspdarknet_s, cspdarknet_m, cspdarknet_l, cspdarknet_x
-
-def build_cspdarknet(args):
-    # build cspdarknet model
-    if   args.model == 'cspdarknet_n':
-        model = cspdarknet_n(args.img_dim, args.num_classes)
-    elif args.model == 'cspdarknet_s':
-        model = cspdarknet_s(args.img_dim, args.num_classes)
-    elif args.model == 'cspdarknet_m':
-        model = cspdarknet_m(args.img_dim, args.num_classes)
-    elif args.model == 'cspdarknet_l':
-        model = cspdarknet_l(args.img_dim, args.num_classes)
-    elif args.model == 'cspdarknet_x':
-        model = cspdarknet_x(args.img_dim, args.num_classes)
-    else:
-        raise NotImplementedError("Unknown cspdarknet: {}".format(args.model))
-    
-    return model

+ 0 - 170
iclab/models/cspdarknet/cspdarknet.py

@@ -1,170 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .modules import BasicConv, CSPBlock
-except:
-    from  modules import BasicConv, CSPBlock
-
-
-# ---------------------------- CSPDarkNet ----------------------------
-# CSPDarkNet
-class CSPDarkNet(nn.Module):
-    def __init__(self, img_dim=3, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False, num_classes=1000):
-        super(CSPDarkNet, self).__init__()
-        # ---------------- Basic parameters ----------------
-        self.width_factor = width
-        self.depth_factor = depth
-        self.feat_dims = [round(64 * width),
-                          round(128 * width),
-                          round(256 * width),
-                          round(512 * width),
-                          round(1024 * width)
-                          ]
-
-        # ---------------- Model parameters ----------------
-        ## P1/2
-        self.layer_1 = BasicConv(img_dim, self.feat_dims[0],
-                                 kernel_size=6, padding=2, stride=2,
-                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        
-        ## P2/4
-        self.layer_2 = nn.Sequential(
-            BasicConv(self.feat_dims[0], self.feat_dims[1],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            CSPBlock(self.feat_dims[1],
-                     self.feat_dims[1],
-                     num_blocks   = round(3*depth),
-                     expand_ratio = 0.5,
-                     shortcut     = True,
-                     act_type     = act_type,
-                     norm_type    = norm_type,
-                     depthwise    = depthwise)
-        )
-        # P3/8
-        self.layer_3 = nn.Sequential(
-            BasicConv(self.feat_dims[1], self.feat_dims[2],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            CSPBlock(self.feat_dims[2],
-                     self.feat_dims[2],
-                     num_blocks   = round(9*depth),
-                     expand_ratio = 0.5,
-                     shortcut     = True,
-                     act_type     = act_type,
-                     norm_type    = norm_type,
-                     depthwise    = depthwise)
-        )
-        # P4/16
-        self.layer_4 = nn.Sequential(
-            BasicConv(self.feat_dims[2], self.feat_dims[3],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            CSPBlock(self.feat_dims[3],
-                     self.feat_dims[3],
-                     num_blocks   = round(9*depth),
-                     expand_ratio = 0.5,
-                     shortcut     = True,
-                     act_type     = act_type,
-                     norm_type    = norm_type,
-                     depthwise    = depthwise)
-        )
-        # P5/32
-        self.layer_5 = nn.Sequential(
-            BasicConv(self.feat_dims[3], self.feat_dims[4],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            CSPBlock(self.feat_dims[4],
-                     self.feat_dims[4],
-                     num_blocks   = round(3*depth),
-                     expand_ratio = 0.5,
-                     shortcut     = True,
-                     act_type     = act_type,
-                     norm_type    = norm_type,
-                     depthwise    = depthwise)
-        )
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(self.feat_dims[4], num_classes)
-
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-
-        c5 = self.avgpool(c5)
-        c5 = torch.flatten(c5, 1)
-        c5 = self.fc(c5)
-
-        return c5
-
-
-# ---------------------------- Functions ----------------------------
-## Build CSPDarkNet variants
-# ------------------------ Model Functions ------------------------
-def cspdarknet_n(img_dim=3, num_classes=1000) -> CSPDarkNet:
-    return CSPDarkNet(img_dim=img_dim,
-                       width=0.25,
-                       depth=0.34,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def cspdarknet_s(img_dim=3, num_classes=1000) -> CSPDarkNet:
-    return CSPDarkNet(img_dim=img_dim,
-                       width=0.50,
-                       depth=0.34,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def cspdarknet_m(img_dim=3, num_classes=1000) -> CSPDarkNet:
-    return CSPDarkNet(img_dim=img_dim,
-                       width=0.75,
-                       depth=0.67,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def cspdarknet_l(img_dim=3, num_classes=1000) -> CSPDarkNet:
-    return CSPDarkNet(img_dim=img_dim,
-                       width=1.0,
-                       depth=1.0,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def cspdarknet_x(img_dim=3, num_classes=1000) -> CSPDarkNet:
-    return CSPDarkNet(img_dim=img_dim,
-                       width=1.25,
-                       depth=1.34,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-
-if __name__ == '__main__':
-    import torch
-    from thop import profile
-
-    # build model
-    model = cspdarknet_s()
-
-    x = torch.randn(1, 3, 224, 224)
-    print('==============================')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))

+ 0 - 136
iclab/models/cspdarknet/modules.py

@@ -1,136 +0,0 @@
-import torch
-import torch.nn as nn
-from   typing import List
-
-# --------------------- Basic modules ---------------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-    elif act_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-        
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
-class BasicConv(nn.Module):
-    def __init__(self, 
-                 in_dim,                   # in channels
-                 out_dim,                  # out channels 
-                 kernel_size=1,            # kernel size 
-                 padding=0,                # padding
-                 stride=1,                 # stride
-                 dilation=1,               # dilation
-                 act_type  :str = 'lrelu', # activation
-                 norm_type :str = 'BN',    # normalization
-                 depthwise :bool = False
-                ):
-        super(BasicConv, self).__init__()
-        self.depthwise = depthwise
-        if not depthwise:
-            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1)
-            self.norm = get_norm(norm_type, out_dim)
-        else:
-            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
-            self.norm1 = get_norm(norm_type, in_dim)
-            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
-            self.norm2 = get_norm(norm_type, out_dim)
-        self.act  = get_activation(act_type)
-
-    def forward(self, x):
-        if not self.depthwise:
-            return self.act(self.norm(self.conv(x)))
-        else:
-            # Depthwise conv
-            x = self.norm1(self.conv1(x))
-            # Pointwise conv
-            x = self.act(self.norm2(self.conv2(x)))
-            return x
-
-
-# ---------------------------- Basic Modules ----------------------------
-class YoloBottleneck(nn.Module):
-    def __init__(self,
-                 in_dim       :int,
-                 out_dim      :int,
-                 kernel_size  :List  = [1, 3],
-                 expand_ratio :float = 0.5,
-                 shortcut     :bool  = False,
-                 act_type     :str   = 'silu',
-                 norm_type    :str   = 'BN',
-                 depthwise    :bool  = False,
-                 ) -> None:
-        super(YoloBottleneck, self).__init__()
-        inter_dim = int(out_dim * expand_ratio)
-        # ----------------- Network setting -----------------
-        self.conv_layer1 = BasicConv(in_dim, inter_dim,
-                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
-                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.conv_layer2 = BasicConv(inter_dim, out_dim,
-                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
-                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.shortcut = shortcut and in_dim == out_dim
-
-    def forward(self, x):
-        h = self.conv_layer2(self.conv_layer1(x))
-
-        return x + h if self.shortcut else h
-
-class CSPBlock(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 num_blocks   :int   = 1,
-                 expand_ratio :float = 0.5,
-                 shortcut     :bool  = False,
-                 act_type     :str   = 'silu',
-                 norm_type    :str   = 'BN',
-                 depthwise    :bool  = False,
-                 ):
-        super(CSPBlock, self).__init__()
-        # ---------- Basic parameters ----------
-        self.num_blocks = num_blocks
-        self.expand_ratio = expand_ratio
-        self.shortcut = shortcut
-        inter_dim = round(out_dim * expand_ratio)
-        # ---------- Model parameters ----------
-        self.conv_layer_1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.conv_layer_2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.conv_layer_3 = BasicConv(inter_dim * 2, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.module       = nn.Sequential(*[YoloBottleneck(inter_dim,
-                                                           inter_dim,
-                                                           kernel_size  = [1, 3],
-                                                           expand_ratio = 1.0,
-                                                           shortcut     = shortcut,
-                                                           act_type     = act_type,
-                                                           norm_type    = norm_type,
-                                                           depthwise    = depthwise)
-                                                           for _ in range(num_blocks)
-                                                           ])
-
-    def forward(self, x):
-        x1 = self.conv_layer_1(x)
-        x2 = self.module(self.conv_layer_2(x))
-        out = self.conv_layer_3(torch.cat([x1, x2], dim=1))
-
-        return out
-    
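For intuition, a small shape-check sketch for the `CSPBlock` above (channel counts are arbitrary, and the import assumes the `iclab/` sources are on the path): both 1x1 projections produce `round(out_dim * expand_ratio)` channels, the bottleneck stack keeps that width, and the final 1x1 fuses the concatenation back to `out_dim`.

```python
import torch

from models.cspdarknet.modules import CSPBlock  # assumes running from the iclab/ directory

block = CSPBlock(in_dim=64, out_dim=128, num_blocks=2, expand_ratio=0.5, shortcut=True)

x = torch.randn(1, 64, 56, 56)
y = block(x)
# inter_dim = round(128 * 0.5) = 64; conv_layer_3 maps 2 * 64 = 128 channels -> out_dim = 128
print(y.shape)  # torch.Size([1, 128, 56, 56])
```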

+ 0 - 18
iclab/models/darknet/build.py

@@ -1,18 +0,0 @@
-from .darknet import darknet_n, darknet_s, darknet_m, darknet_l, darknet_x
-
-def build_darknet(args):
-    # build darknet model
-    if   args.model == 'darknet_n':
-        model = darknet_n(args.img_dim, args.num_classes)
-    elif args.model == 'darknet_s':
-        model = darknet_s(args.img_dim, args.num_classes)
-    elif args.model == 'darknet_m':
-        model = darknet_m(args.img_dim, args.num_classes)
-    elif args.model == 'darknet_l':
-        model = darknet_l(args.img_dim, args.num_classes)
-    elif args.model == 'darknet_x':
-        model = darknet_x(args.img_dim, args.num_classes)
-    else:
-        raise NotImplementedError("Unknown darknet: {}".format(args.model))
-    
-    return model

+ 0 - 170
iclab/models/darknet/darknet.py

@@ -1,170 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .modules import BasicConv, ResBlock
-except:
-    from  modules import BasicConv, ResBlock
-
-
-# ---------------------------- DarkNet ----------------------------
-# Modified DarkNet
-class DarkNet(nn.Module):
-    def __init__(self, img_dim=3, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False, num_classes=1000):
-        super(DarkNet, self).__init__()
-        # ---------------- Basic parameters ----------------
-        self.width_factor = width
-        self.depth_factor = depth
-        self.feat_dims = [round(64 * width),
-                          round(128 * width),
-                          round(256 * width),
-                          round(512 * width),
-                          round(1024 * width)
-                          ]
-
-        # ---------------- Model parameters ----------------
-        ## P1/2
-        self.layer_1 = BasicConv(img_dim, self.feat_dims[0],
-                                 kernel_size=6, padding=2, stride=2,
-                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        
-        ## P2/4
-        self.layer_2 = nn.Sequential(
-            BasicConv(self.feat_dims[0], self.feat_dims[1],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ResBlock(self.feat_dims[1],
-                     self.feat_dims[1],
-                     num_blocks   = round(3*depth),
-                     expand_ratio = 0.5,
-                     shortcut     = True,
-                     act_type     = act_type,
-                     norm_type    = norm_type,
-                     depthwise    = depthwise)
-        )
-        # P3/8
-        self.layer_3 = nn.Sequential(
-            BasicConv(self.feat_dims[1], self.feat_dims[2],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ResBlock(self.feat_dims[2],
-                     self.feat_dims[2],
-                     num_blocks   = round(9*depth),
-                     expand_ratio = 0.5,
-                     shortcut     = True,
-                     act_type     = act_type,
-                     norm_type    = norm_type,
-                     depthwise    = depthwise)
-        )
-        # P4/16
-        self.layer_4 = nn.Sequential(
-            BasicConv(self.feat_dims[2], self.feat_dims[3],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ResBlock(self.feat_dims[3],
-                     self.feat_dims[3],
-                     num_blocks   = round(9*depth),
-                     expand_ratio = 0.5,
-                     shortcut     = True,
-                     act_type     = act_type,
-                     norm_type    = norm_type,
-                     depthwise    = depthwise)
-        )
-        # P5/32
-        self.layer_5 = nn.Sequential(
-            BasicConv(self.feat_dims[3], self.feat_dims[4],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ResBlock(self.feat_dims[4],
-                     self.feat_dims[4],
-                     num_blocks   = round(3*depth),
-                     expand_ratio = 0.5,
-                     shortcut     = True,
-                     act_type     = act_type,
-                     norm_type    = norm_type,
-                     depthwise    = depthwise)
-        )
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(self.feat_dims[4], num_classes)
-
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-
-        c5 = self.avgpool(c5)
-        c5 = torch.flatten(c5, 1)
-        c5 = self.fc(c5)
-
-        return c5
-
-
-# ---------------------------- Functions ----------------------------
-## Build DarkNet variants
-# ------------------------ Model Functions ------------------------
-def darknet_n(img_dim=3, num_classes=1000) -> DarkNet:
-    return DarkNet(img_dim=img_dim,
-                       width=0.25,
-                       depth=0.34,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def darknet_s(img_dim=3, num_classes=1000) -> DarkNet:
-    return DarkNet(img_dim=img_dim,
-                       width=0.50,
-                       depth=0.34,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def darknet_m(img_dim=3, num_classes=1000) -> DarkNet:
-    return DarkNet(img_dim=img_dim,
-                       width=0.75,
-                       depth=0.67,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def darknet_l(img_dim=3, num_classes=1000) -> DarkNet:
-    return DarkNet(img_dim=img_dim,
-                       width=1.0,
-                       depth=1.0,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def darknet_x(img_dim=3, num_classes=1000) -> DarkNet:
-    return DarkNet(img_dim=img_dim,
-                       width=1.25,
-                       depth=1.34,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-
-if __name__ == '__main__':
-    import torch
-    from thop import profile
-
-    # build model
-    model = darknet_s()
-
-    x = torch.randn(1, 3, 224, 224)
-    print('==============================')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))

+ 0 - 130
iclab/models/darknet/modules.py

@@ -1,130 +0,0 @@
-import torch
-import torch.nn as nn
-from   typing import List
-
-# --------------------- Basic modules ---------------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-    elif act_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-        
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
-class BasicConv(nn.Module):
-    def __init__(self, 
-                 in_dim,                   # in channels
-                 out_dim,                  # out channels 
-                 kernel_size=1,            # kernel size 
-                 padding=0,                # padding
-                 stride=1,                 # stride
-                 dilation=1,               # dilation
-                 act_type  :str = 'lrelu', # activation
-                 norm_type :str = 'BN',    # normalization
-                 depthwise :bool = False
-                ):
-        super(BasicConv, self).__init__()
-        self.depthwise = depthwise
-        if not depthwise:
-            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1)
-            self.norm = get_norm(norm_type, out_dim)
-        else:
-            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
-            self.norm1 = get_norm(norm_type, in_dim)
-            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
-            self.norm2 = get_norm(norm_type, out_dim)
-        self.act  = get_activation(act_type)
-
-    def forward(self, x):
-        if not self.depthwise:
-            return self.act(self.norm(self.conv(x)))
-        else:
-            # Depthwise conv
-            x = self.norm1(self.conv1(x))
-            # Pointwise conv
-            x = self.act(self.norm2(self.conv2(x)))
-            return x
-
-
-# ---------------------------- Basic Modules ----------------------------
-class YoloBottleneck(nn.Module):
-    def __init__(self,
-                 in_dim       :int,
-                 out_dim      :int,
-                 kernel_size  :List  = [1, 3],
-                 expand_ratio :float = 0.5,
-                 shortcut     :bool  = False,
-                 act_type     :str   = 'silu',
-                 norm_type    :str   = 'BN',
-                 depthwise    :bool  = False,
-                 ) -> None:
-        super(YoloBottleneck, self).__init__()
-        inter_dim = int(out_dim * expand_ratio)
-        # ----------------- Network setting -----------------
-        self.conv_layer1 = BasicConv(in_dim, inter_dim,
-                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
-                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.conv_layer2 = BasicConv(inter_dim, out_dim,
-                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
-                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.shortcut = shortcut and in_dim == out_dim
-
-    def forward(self, x):
-        h = self.conv_layer2(self.conv_layer1(x))
-
-        return x + h if self.shortcut else h
-
-class ResBlock(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 num_blocks   :int   = 1,
-                 expand_ratio :float = 0.5,
-                 shortcut     :bool  = False,
-                 act_type     :str   = 'silu',
-                 norm_type    :str   = 'BN',
-                 depthwise    :bool  = False,
-                 ):
-        super(ResBlock, self).__init__()
-        # ---------- Basic parameters ----------
-        self.num_blocks = num_blocks
-        self.expand_ratio = expand_ratio
-        self.shortcut = shortcut
-        # ---------- Model parameters ----------
-        self.module = nn.Sequential(*[YoloBottleneck(in_dim,
-                                                     out_dim,
-                                                     kernel_size  = [1, 3],
-                                                     expand_ratio = expand_ratio,
-                                                     shortcut     = shortcut,
-                                                     act_type     = act_type,
-                                                     norm_type    = norm_type,
-                                                     depthwise    = depthwise)
-                                                     for _ in range(num_blocks)
-                                                     ])
-
-    def forward(self, x):
-        out = self.module(x)
-
-        return out
-    

+ 0 - 16
iclab/models/elandarknet/build.py

@@ -1,16 +0,0 @@
-from .elandarknet import elandarknet_n, elandarknet_s, elandarknet_m, elandarknet_l, elandarknet_x
-
-def build_elandarknet(args):
-    # build elandarknet model
-    if   args.model == 'elandarknet_n':
-        model = elandarknet_n(args.img_dim, args.num_classes)
-    elif args.model == 'elandarknet_s':
-        model = elandarknet_s(args.img_dim, args.num_classes)
-    elif args.model == 'elandarknet_m':
-        model = elandarknet_m(args.img_dim, args.num_classes)
-    elif args.model == 'elandarknet_l':
-        model = elandarknet_l(args.img_dim, args.num_classes)
-    elif args.model == 'elandarknet_x':
-        model = elandarknet_x(args.img_dim, args.num_classes)
-    
-    else:
-        raise NotImplementedError("Unknown elandarknet: {}".format(args.model))
-    

+ 0 - 171
iclab/models/elandarknet/elandarknet.py

@@ -1,171 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .modules import BasicConv, ELANLayer
-except:
-    from  modules import BasicConv, ELANLayer
-   
-
-## ELAN-based DarkNet
-class ELANDarkNet(nn.Module):
-    def __init__(self, img_dim=3, width=1.0, depth=1.0, ratio=1.0, num_classes=1000, act_type='silu', norm_type='BN', depthwise=False):
-        super(ELANDarkNet, self).__init__()
-        # ---------------- Basic parameters ----------------
-        self.width_factor = width
-        self.depth_factor = depth
-        self.last_stage_factor = ratio
-        self.feat_dims = [round(64 * width),
-                          round(128 * width),
-                          round(256 * width),
-                          round(512 * width),
-                          round(512 * width * ratio)
-                          ]
-        # ---------------- Network parameters ----------------
-        ## P1/2
-        self.layer_1 = BasicConv(img_dim, self.feat_dims[0],
-                                 kernel_size=3, padding=1, stride=2,
-                                 act_type=act_type, norm_type=norm_type)
-        ## P2/4
-        self.layer_2 = nn.Sequential(
-            BasicConv(self.feat_dims[0], self.feat_dims[1],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANLayer(in_dim     = self.feat_dims[1],
-                      out_dim    = self.feat_dims[1],
-                      num_blocks = round(3*depth),
-                      shortcut   = True,
-                      act_type   = act_type,
-                      norm_type  = norm_type,
-                      depthwise  = depthwise,
-                      )
-        )
-        ## P3/8
-        self.layer_3 = nn.Sequential(
-            BasicConv(self.feat_dims[1], self.feat_dims[2],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANLayer(in_dim     = self.feat_dims[2],
-                      out_dim    = self.feat_dims[2],
-                      num_blocks = round(6*depth),
-                      shortcut   = True,
-                      act_type   = act_type,
-                      norm_type  = norm_type,
-                      depthwise  = depthwise,
-                      )
-        )
-        ## P4/16
-        self.layer_4 = nn.Sequential(
-            BasicConv(self.feat_dims[2], self.feat_dims[3],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANLayer(in_dim     = self.feat_dims[3],
-                      out_dim    = self.feat_dims[3],
-                      num_blocks = round(6*depth),
-                      shortcut   = True,
-                      act_type   = act_type,
-                      norm_type  = norm_type,
-                      depthwise  = depthwise,
-                      )
-        )
-        ## P5/32
-        self.layer_5 = nn.Sequential(
-            BasicConv(self.feat_dims[3], self.feat_dims[4],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            ELANLayer(in_dim     = self.feat_dims[4],
-                      out_dim    = self.feat_dims[4],
-                      num_blocks = round(3*depth),
-                      shortcut   = True,
-                      act_type   = act_type,
-                      norm_type  = norm_type,
-                      depthwise  = depthwise,
-                      )
-        )
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(self.feat_dims[4], num_classes)
-
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-
-        c5 = self.avgpool(c5)
-        c5 = torch.flatten(c5, 1)
-        c5 = self.fc(c5)
-
-        return c5
-
-
-# ------------------------ Model Functions ------------------------
-def elandarknet_n(img_dim=3, num_classes=1000) -> ELANDarkNet:
-    return ELANDarkNet(img_dim=img_dim,
-                       width=0.25,
-                       depth=0.34,
-                       ratio=2.0,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def elandarknet_s(img_dim=3, num_classes=1000) -> ELANDarkNet:
-    return ELANDarkNet(img_dim=img_dim,
-                       width=0.50,
-                       depth=0.34,
-                       ratio=2.0,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def elandarknet_m(img_dim=3, num_classes=1000) -> ELANDarkNet:
-    return ELANDarkNet(img_dim=img_dim,
-                       width=0.75,
-                       depth=0.67,
-                       ratio=1.5,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def elandarknet_l(img_dim=3, num_classes=1000) -> ELANDarkNet:
-    return ELANDarkNet(img_dim=img_dim,
-                       width=1.0,
-                       depth=1.0,
-                       ratio=1.0,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-def elandarknet_x(img_dim=3, num_classes=1000) -> ELANDarkNet:
-    return ELANDarkNet(img_dim=img_dim,
-                       width=1.25,
-                       depth=1.34,
-                       ratio=1.0,
-                       act_type='silu',
-                       norm_type='BN',
-                       depthwise=False,
-                       num_classes=num_classes
-                       )
-
-
-if __name__ == '__main__':
-    import torch
-    from thop import profile
-
-    # build model
-    model = elandarknet_s()
-
-    x = torch.randn(1, 3, 224, 224)
-    print('==============================')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))

+ 0 - 135
iclab/models/elandarknet/modules.py

@@ -1,135 +0,0 @@
-import torch
-import torch.nn as nn
-from   typing import List
-
-# --------------------- Basic modules ---------------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-    elif act_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-        
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
-class BasicConv(nn.Module):
-    def __init__(self, 
-                 in_dim,                   # in channels
-                 out_dim,                  # out channels 
-                 kernel_size=1,            # kernel size 
-                 padding=0,                # padding
-                 stride=1,                 # stride
-                 dilation=1,               # dilation
-                 act_type  :str = 'lrelu', # activation
-                 norm_type :str = 'BN',    # normalization
-                 depthwise :bool = False
-                ):
-        super(BasicConv, self).__init__()
-        self.depthwise = depthwise
-        if not depthwise:
-            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1)
-            self.norm = get_norm(norm_type, out_dim)
-        else:
-            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
-            self.norm1 = get_norm(norm_type, in_dim)
-            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
-            self.norm2 = get_norm(norm_type, out_dim)
-        self.act  = get_activation(act_type)
-
-    def forward(self, x):
-        if not self.depthwise:
-            return self.act(self.norm(self.conv(x)))
-        else:
-            # Depthwise conv
-            x = self.norm1(self.conv1(x))
-            # Pointwise conv
-            x = self.act(self.norm2(self.conv2(x)))
-            return x
-
-
-# --------------------- Yolov8 modules ---------------------
-class Bottleneck(nn.Module):
-    def __init__(self,
-                 in_dim       :int,
-                 out_dim      :int,
-                 kernel_size  :List  = [1, 3],
-                 expand_ratio :float = 0.5,
-                 shortcut     :bool  = False,
-                 act_type     :str   = 'silu',
-                 norm_type    :str   = 'BN',
-                 depthwise    :bool  = False,
-                 ) -> None:
-        super(Bottleneck, self).__init__()
-        inter_dim = int(out_dim * expand_ratio)
-        # ----------------- Network setting -----------------
-        self.conv_layer1 = BasicConv(in_dim, inter_dim,
-                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
-                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.conv_layer2 = BasicConv(inter_dim, out_dim,
-                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
-                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.shortcut = shortcut and in_dim == out_dim
-
-    def forward(self, x):
-        h = self.conv_layer2(self.conv_layer1(x))
-
-        return x + h if self.shortcut else h
-
-class ELANLayer(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 expand_ratio :float = 0.5,
-                 num_blocks   :int   = 1,
-                 shortcut     :bool  = False,
-                 act_type     :str   = 'silu',
-                 norm_type    :str   = 'BN',
-                 depthwise    :bool  = False,
-                 ) -> None:
-        super(ELANLayer, self).__init__()
-        self.inter_dim = round(out_dim * expand_ratio)
-        self.input_proj  = BasicConv(in_dim, self.inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.module = nn.ModuleList([Bottleneck(self.inter_dim,
-                                                    self.inter_dim,
-                                                    kernel_size  = [3, 3],
-                                                    expand_ratio = 1.0,
-                                                    shortcut     = shortcut,
-                                                    act_type     = act_type,
-                                                    norm_type    = norm_type,
-                                                    depthwise    = depthwise)
-                                                    for _ in range(num_blocks)])
-
-    def forward(self, x):
-        # Input proj
-        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
-        out = list([x1, x2])
-
-        # Bottleneck
-        out.extend(m(out[-1]) for m in self.module)
-
-        # Output proj
-        out = self.output_proj(torch.cat(out, dim=1))
-
-        return out
-   
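Similarly, a shape-check sketch for the `ELANLayer` above (the import path is again an assumption): `input_proj` expands the input to `2 * inter_dim` channels and is chunked into two halves, each bottleneck appends one more `inter_dim`-channel map, and `output_proj` fuses the `(2 + num_blocks) * inter_dim` concatenation back to `out_dim`.

```python
import torch

from models.elandarknet.modules import ELANLayer  # assumes running from the iclab/ directory

layer = ELANLayer(in_dim=64, out_dim=128, expand_ratio=0.5, num_blocks=3, shortcut=True)

x = torch.randn(1, 64, 28, 28)
y = layer(x)
# inter_dim = round(128 * 0.5) = 64; output_proj sees (2 + 3) * 64 = 320 channels -> out_dim = 128
print(y.shape)  # torch.Size([1, 128, 28, 28])
```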

+ 0 - 12
iclab/models/gelan/build.py

@@ -1,12 +0,0 @@
-from .gelan import gelan_s, gelan_c
-
-def build_gelan(args):
-    # build gelan model
-    if   args.model == 'gelan_s':
-        model = gelan_s(args.img_dim, args.num_classes)
-    elif args.model == 'gelan_c':
-        model = gelan_c(args.img_dim, args.num_classes)
-    else:
-        raise NotImplementedError("Unknown gelan: {}".format(args.model))
-    
-    return model

+ 0 - 233
iclab/models/gelan/gelan.py

@@ -1,233 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .modules import BasicConv, RepGElanLayer, ADown
-except:
-    from  modules import BasicConv, RepGElanLayer, ADown
-
-
-# ---------------------------- GELAN Backbone ----------------------------
-class GElanCBackbone(nn.Module):
-    def __init__(self, img_dim=3, num_classes=1000, act_type='silu', norm_type='BN', depthwise=False):
-        super(GElanCBackbone, self).__init__()
-        # ------------------ Basic setting ------------------
-        self.feat_dims = {
-            "c1": [64],
-            "c2": [128, [128, 64],  256],
-            "c3": [256, [256, 128], 512],
-            "c4": [512, [512, 256], 512],
-            "c5": [512, [512, 256], 512],
-        }
-        
-        # ------------------ Network setting ------------------
-        ## P1/2
-        self.layer_1 = BasicConv(img_dim, self.feat_dims["c1"][0],
-                                 kernel_size=3, padding=1, stride=2,
-                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        # P2/4
-        self.layer_2 = nn.Sequential(
-            BasicConv(self.feat_dims["c1"][0], self.feat_dims["c2"][0],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            RepGElanLayer(in_dim     = self.feat_dims["c2"][0],
-                          inter_dims = self.feat_dims["c2"][1],
-                          out_dim    = self.feat_dims["c2"][2],
-                          num_blocks = 1,
-                          shortcut   = True,
-                          act_type   = act_type,
-                          norm_type  = norm_type,
-                          depthwise  = depthwise)
-        )
-        # P3/8
-        self.layer_3 = nn.Sequential(
-            ADown(self.feat_dims["c2"][2], self.feat_dims["c3"][0],
-                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            RepGElanLayer(in_dim     = self.feat_dims["c3"][0],
-                          inter_dims = self.feat_dims["c3"][1],
-                          out_dim    = self.feat_dims["c3"][2],
-                          num_blocks = 1,
-                          shortcut   = True,
-                          act_type   = act_type,
-                          norm_type  = norm_type,
-                          depthwise  = depthwise)
-        )
-        # P4/16
-        self.layer_4 = nn.Sequential(
-            ADown(self.feat_dims["c3"][2], self.feat_dims["c4"][0],
-                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            RepGElanLayer(in_dim     = self.feat_dims["c4"][0],
-                          inter_dims = self.feat_dims["c4"][1],
-                          out_dim    = self.feat_dims["c4"][2],
-                          num_blocks = 1,
-                          shortcut   = True,
-                          act_type   = act_type,
-                          norm_type  = norm_type,
-                          depthwise  = depthwise)
-        )
-        # P5/32
-        self.layer_5 = nn.Sequential(
-            ADown(self.feat_dims["c4"][2], self.feat_dims["c5"][0],
-                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            RepGElanLayer(in_dim     = self.feat_dims["c5"][0],
-                          inter_dims = self.feat_dims["c5"][1],
-                          out_dim    = self.feat_dims["c5"][2],
-                          num_blocks = 1,
-                          shortcut   = True,
-                          act_type   = act_type,
-                          norm_type  = norm_type,
-                          depthwise  = depthwise)
-        )
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(self.feat_dims["c5"][2], num_classes)
-
-        # Initialize all layers
-        self.init_weights()
-        
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-
-        c5 = self.avgpool(c5)
-        c5 = torch.flatten(c5, 1)
-        c5 = self.fc(c5)
-
-        return c5
-
-class GElanSBackbone(nn.Module):
-    def __init__(self, img_dim=3, num_classes=1000, act_type='silu', norm_type='BN', depthwise=False):
-        super(GElanSBackbone, self).__init__()
-        # ------------------ Basic setting ------------------
-        self.feat_dims = {
-            "c1": [32],
-            "c2": [64,  [64, 32],   64],
-            "c3": [64,  [64, 32],   128],
-            "c4": [128, [128, 64],  256],
-            "c5": [256, [256, 128], 256],
-        }
-        
-        # ------------------ Network setting ------------------
-        ## P1/2
-        self.layer_1 = BasicConv(img_dim, self.feat_dims["c1"][0],
-                                 kernel_size=3, padding=1, stride=2,
-                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        # P2/4
-        self.layer_2 = nn.Sequential(
-            BasicConv(self.feat_dims["c1"][0], self.feat_dims["c2"][0],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            RepGElanLayer(in_dim     = self.feat_dims["c2"][0],
-                          inter_dims = self.feat_dims["c2"][1],
-                          out_dim    = self.feat_dims["c2"][2],
-                          num_blocks = 3,
-                          shortcut   = True,
-                          act_type   = act_type,
-                          norm_type  = norm_type,
-                          depthwise  = depthwise)
-        )
-        # P3/8
-        self.layer_3 = nn.Sequential(
-            ADown(self.feat_dims["c2"][2], self.feat_dims["c3"][0],
-                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            RepGElanLayer(in_dim     = self.feat_dims["c3"][0],
-                          inter_dims = self.feat_dims["c3"][1],
-                          out_dim    = self.feat_dims["c3"][2],
-                          num_blocks = 3,
-                          shortcut   = True,
-                          act_type   = act_type,
-                          norm_type  = norm_type,
-                          depthwise  = depthwise)
-        )
-        # P4/16
-        self.layer_4 = nn.Sequential(
-            ADown(self.feat_dims["c3"][2], self.feat_dims["c4"][0],
-                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            RepGElanLayer(in_dim     = self.feat_dims["c4"][0],
-                          inter_dims = self.feat_dims["c4"][1],
-                          out_dim    = self.feat_dims["c4"][2],
-                          num_blocks = 3,
-                          shortcut   = True,
-                          act_type   = act_type,
-                          norm_type  = norm_type,
-                          depthwise  = depthwise)
-        )
-        # P5/32
-        self.layer_5 = nn.Sequential(
-            ADown(self.feat_dims["c4"][2], self.feat_dims["c5"][0],
-                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
-            RepGElanLayer(in_dim     = self.feat_dims["c5"][0],
-                          inter_dims = self.feat_dims["c5"][1],
-                          out_dim    = self.feat_dims["c5"][2],
-                          num_blocks = 3,
-                          shortcut   = True,
-                          act_type   = act_type,
-                          norm_type  = norm_type,
-                          depthwise  = depthwise)
-        )
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(self.feat_dims["c5"][2], num_classes)
-
-        # Initialize all layers
-        self.init_weights()
-        
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-
-        c5 = self.avgpool(c5)
-        c5 = torch.flatten(c5, 1)
-        c5 = self.fc(c5)
-
-        return c5
-
-
-# ---------------------------- Functions ----------------------------
-def gelan_c(img_dim=3, num_classes=1000):
-    return GElanCBackbone(img_dim,
-                          num_classes=num_classes,
-                          act_type='silu',
-                          norm_type='BN',
-                          depthwise=False)
-
-def gelan_s(img_dim=3, num_classes=1000):
-    return GElanSBackbone(img_dim,
-                          num_classes=num_classes,
-                          act_type='silu',
-                          norm_type='BN',
-                          depthwise=False)
-
-
-if __name__ == '__main__':
-    import torch
-    from thop import profile
-
-    # build model
-    model = gelan_c()
-
-    x = torch.randn(1, 3, 224, 224)
-    print('==============================')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))

+ 0 - 312
iclab/models/gelan/modules.py

@@ -1,312 +0,0 @@
-import numpy as np
-import torch
-import torch.nn as nn
-from typing import List
-
-
-# --------------------- Basic modules ---------------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-    elif act_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-        
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
-class BasicConv(nn.Module):
-    def __init__(self, 
-                 in_dim,                   # in channels
-                 out_dim,                  # out channels 
-                 kernel_size=1,            # kernel size 
-                 padding=0,                # padding
-                 stride=1,                 # stride
-                 dilation=1,               # dilation
-                 group=1,                  # group
-                 act_type  :str = 'lrelu', # activation
-                 norm_type :str = 'BN',    # normalization
-                 depthwise :bool = False
-                ):
-        super(BasicConv, self).__init__()
-        self.depthwise = depthwise
-        if not depthwise:
-            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=group)
-            self.norm = get_norm(norm_type, out_dim)
-        else:
-            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
-            self.norm1 = get_norm(norm_type, in_dim)
-            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
-            self.norm2 = get_norm(norm_type, out_dim)
-        self.act  = get_activation(act_type)
-
-    def forward(self, x):
-        if not self.depthwise:
-            return self.act(self.norm(self.conv(x)))
-        else:
-            # Depthwise conv
-            x = self.norm1(self.conv1(x))
-            # Pointwise conv
-            x = self.act(self.norm2(self.conv2(x)))
-            return x
-
-
-# --------------------- GELAN modules (from yolov9) ---------------------
-class ADown(nn.Module):
-    def __init__(self, in_dim, out_dim, act_type="silu", norm_type="BN", depthwise=False):
-        super().__init__()
-        inter_dim = out_dim // 2
-        self.conv_layer_1 = BasicConv(in_dim // 2, inter_dim,
-                                    kernel_size=3, padding=1, stride=2,
-                                    act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.conv_layer_2 = BasicConv(in_dim // 2, inter_dim, kernel_size=1,
-                                    act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-    def forward(self, x):
-        x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
-        x1,x2 = x.chunk(2, 1)
-        x1 = self.conv_layer_1(x1)
-        x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
-        x2 = self.conv_layer_2(x2)
-
-        return torch.cat((x1, x2), 1)
-
-class RepConvN(nn.Module):
-    """RepConv is a basic rep-style block, including training and deploy status
-    This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
-    """
-    def __init__(self, in_dim, out_dim, k=3, s=1, p=1, g=1, act_type='silu', norm_type='BN', depthwise=False):
-        super().__init__()
-        assert k == 3 and p == 1
-        self.g = g
-        self.in_dim = in_dim
-        self.out_dim = out_dim
-        self.act = get_activation(act_type)
-
-        self.bn = None
-        self.conv1 = BasicConv(in_dim, out_dim,
-                               kernel_size=k, padding=p, stride=s, group=g,
-                               act_type=None, norm_type=norm_type, depthwise=depthwise)
-        self.conv2 = BasicConv(in_dim, out_dim,
-                               kernel_size=1, padding=(p - k // 2), stride=s, group=g,
-                               act_type=None, norm_type=norm_type, depthwise=depthwise)
-
-    def forward(self, x):
-        """Forward process"""
-        if hasattr(self, 'conv'):
-            return self.forward_fuse(x)
-        else:
-            id_out = 0 if self.bn is None else self.bn(x)
-            return self.act(self.conv1(x) + self.conv2(x) + id_out)
-
-    def forward_fuse(self, x):
-        """Forward process"""
-        return self.act(self.conv(x))
-
-    def get_equivalent_kernel_bias(self):
-        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
-        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
-        kernelid, biasid = self._fuse_bn_tensor(self.bn)
-        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
-
-    def _avg_to_3x3_tensor(self, avgp):
-        channels = self.in_dim
-        groups = self.g
-        kernel_size = avgp.kernel_size
-        input_dim = channels // groups
-        k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
-        k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
-        return k
-
-    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
-        if kernel1x1 is None:
-            return 0
-        else:
-            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
-
-    def _fuse_bn_tensor(self, branch):
-        if branch is None:
-            return 0, 0
-        if isinstance(branch, BasicConv):
-            kernel       = branch.conv.weight
-            running_mean = branch.norm.running_mean
-            running_var  = branch.norm.running_var
-            gamma        = branch.norm.weight
-            beta         = branch.norm.bias
-            eps          = branch.norm.eps
-        elif isinstance(branch, nn.BatchNorm2d):
-            if not hasattr(self, 'id_tensor'):
-                input_dim = self.in_dim // self.g
-                kernel_value = np.zeros((self.in_dim, input_dim, 3, 3), dtype=np.float32)
-                for i in range(self.in_dim):
-                    kernel_value[i, i % input_dim, 1, 1] = 1
-                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
-            kernel       = self.id_tensor
-            running_mean = branch.running_mean
-            running_var  = branch.running_var
-            gamma        = branch.weight
-            beta         = branch.bias
-            eps          = branch.eps
-        std = (running_var + eps).sqrt()
-        t = (gamma / std).reshape(-1, 1, 1, 1)
-        return kernel * t, beta - running_mean * gamma / std
-
-    def fuse_convs(self):
-        if hasattr(self, 'conv'):
-            return
-        kernel, bias = self.get_equivalent_kernel_bias()
-        self.conv = nn.Conv2d(in_channels  = self.conv1.conv.in_channels,
-                              out_channels = self.conv1.conv.out_channels,
-                              kernel_size  = self.conv1.conv.kernel_size,
-                              stride       = self.conv1.conv.stride,
-                              padding      = self.conv1.conv.padding,
-                              dilation     = self.conv1.conv.dilation,
-                              groups       = self.conv1.conv.groups,
-                              bias         = True).requires_grad_(False)
-        self.conv.weight.data = kernel
-        self.conv.bias.data = bias
-        for para in self.parameters():
-            para.detach_()
-        self.__delattr__('conv1')
-        self.__delattr__('conv2')
-        if hasattr(self, 'nm'):
-            self.__delattr__('nm')
-        if hasattr(self, 'bn'):
-            self.__delattr__('bn')
-        if hasattr(self, 'id_tensor'):
-            self.__delattr__('id_tensor')
-
-class RepNBottleneck(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 shortcut=True,
-                 kernel_size=(3, 3),
-                 expansion=0.5,
-                 act_type='silu',
-                 norm_type='BN',
-                 depthwise=False
-                 ):
-        super().__init__()
-        inter_dim = round(out_dim * expansion)
-        self.conv_layer_1 = RepConvN(in_dim, inter_dim, kernel_size[0], p=kernel_size[0]//2, s=1, act_type=act_type, norm_type=norm_type)
-        self.conv_layer_2 = BasicConv(inter_dim, out_dim, kernel_size[1], padding=kernel_size[1]//2, stride=1, act_type=act_type, norm_type=norm_type)
-        self.add = shortcut and in_dim == out_dim
-
-    def forward(self, x):
-        h = self.conv_layer_2(self.conv_layer_1(x))
-        return x + h if self.add else h
-
-class RepNCSP(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 num_blocks=1,
-                 shortcut=True,
-                 expansion=0.5,
-                 act_type='silu',
-                 norm_type='BN',
-                 depthwise=False
-                 ):
-        super().__init__()
-        inter_dim = int(out_dim * expansion)
-        self.conv_layer_1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.conv_layer_2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.conv_layer_3 = BasicConv(2 * inter_dim, out_dim, kernel_size=1)
-        self.module       = nn.Sequential(*(RepNBottleneck(inter_dim,
-                                                           inter_dim,
-                                                           kernel_size = [3, 3],
-                                                           shortcut    = shortcut,
-                                                           expansion   = 1.0,
-                                                           act_type    = act_type,
-                                                           norm_type   = norm_type,
-                                                           depthwise   = depthwise)
-                                                           for _ in range(num_blocks)))
-
-    def forward(self, x):
-        x1 = self.conv_layer_1(x)
-        x2 = self.module(self.conv_layer_2(x))
-
-        return self.conv_layer_3(torch.cat([x1, x2], dim=1))
-
-class RepGElanLayer(nn.Module):
-    """YOLOv9's GELAN module"""
-    def __init__(self,
-                 in_dim     :int,
-                 inter_dims :List,
-                 out_dim    :int,
-                 num_blocks :int   = 1,
-                 shortcut   :bool  = False,
-                 act_type   :str   = 'silu',
-                 norm_type  :str   = 'BN',
-                 depthwise  :bool  = False,
-                 ) -> None:
-        super(RepGElanLayer, self).__init__()
-        # ----------- Basic parameters -----------
-        self.in_dim = in_dim
-        self.inter_dims = inter_dims
-        self.out_dim = out_dim
-
-        # ----------- Network parameters -----------
-        self.conv_layer_1  = BasicConv(in_dim, inter_dims[0], kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.elan_module_1 = nn.Sequential(
-             RepNCSP(inter_dims[0]//2,
-                     inter_dims[1],
-                     num_blocks  = num_blocks,
-                     shortcut    = shortcut,
-                     expansion   = 0.5,
-                     act_type    = act_type,
-                     norm_type   = norm_type,
-                     depthwise   = depthwise),
-            BasicConv(inter_dims[1], inter_dims[1],
-                      kernel_size=3, padding=1,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        self.elan_module_2 = nn.Sequential(
-             RepNCSP(inter_dims[1],
-                     inter_dims[1],
-                     num_blocks  = num_blocks,
-                     shortcut    = shortcut,
-                     expansion   = 0.5,
-                     act_type    = act_type,
-                     norm_type   = norm_type,
-                     depthwise   = depthwise),
-            BasicConv(inter_dims[1], inter_dims[1],
-                      kernel_size=3, padding=1,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        self.conv_layer_2 = BasicConv(inter_dims[0] + 2*self.inter_dims[1], out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-
-
-    def forward(self, x):
-        # Input proj
-        x1, x2 = torch.chunk(self.conv_layer_1(x), 2, dim=1)
-        out = [x1, x2]
-
-        # ELAN module
-        out.append(self.elan_module_1(out[-1]))
-        out.append(self.elan_module_2(out[-1]))
-
-        # Output proj
-        out = self.conv_layer_2(torch.cat(out, dim=1))
-
-        return out
-    
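
The deleted `RepConvN` above follows the RepVGG-style re-parameterization trick: during training it runs a 3x3-conv+BN branch and a 1x1-conv+BN branch in parallel, and `fuse_convs()` folds both into a single 3x3 convolution for deployment. A minimal sketch of that fusion identity, using plain `nn.Conv2d`/`nn.BatchNorm2d` instead of the repo's `BasicConv` (names and sizes here are illustrative only):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    # Fold the BN affine transform and running statistics into the conv kernel/bias.
    std = (bn.running_var + bn.eps).sqrt()
    scale = (bn.weight / std).reshape(-1, 1, 1, 1)
    kernel = conv.weight * scale
    bias = bn.bias - bn.running_mean * bn.weight / std
    return kernel, bias

c = 8
conv3, bn3 = nn.Conv2d(c, c, 3, padding=1, bias=False), nn.BatchNorm2d(c)
conv1, bn1 = nn.Conv2d(c, c, 1, bias=False), nn.BatchNorm2d(c)
bn3.eval(), bn1.eval()   # use running stats, as at deploy time

x = torch.randn(2, c, 16, 16)
y_branches = bn3(conv3(x)) + bn1(conv1(x))           # training-time multi-branch form

k3, b3 = fuse_conv_bn(conv3, bn3)
k1, b1 = fuse_conv_bn(conv1, bn1)
k = k3 + F.pad(k1, [1, 1, 1, 1])                      # pad the 1x1 kernel to 3x3
b = b3 + b1
y_fused = F.conv2d(x, k, b, padding=1)                # deploy-time single conv

print(torch.allclose(y_branches, y_fused, atol=1e-5))  # expected: True
```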

+ 0 - 332
iclab/train.py

@@ -1,332 +0,0 @@
-from copy import deepcopy
-import os
-import time
-import math
-import argparse
-import datetime
-
-# ---------------- Timm components ----------------
-from timm.data.mixup import Mixup
-from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
-
-# ---------------- Torch components ----------------
-import torch
-import torch.backends.cudnn as cudnn
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-# ---------------- Dataset components ----------------
-from data import build_dataset, build_dataloader
-
-# ---------------- Model components ----------------
-from models import build_model
-
-# ---------------- Utils components ----------------
-from utils import distributed_utils
-from utils.ema import ModelEMA
-from utils.misc import setup_seed, print_rank_0, load_model, save_model
-from utils.misc import NativeScalerWithGradNormCount as NativeScaler
-from utils.optimzer import build_optimizer
-from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
-from utils.com_flops_params import FLOPs_and_Params
-
-# ---------------- Training engine ----------------
-from engine import train_one_epoch, evaluate
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    # Input
-    parser.add_argument('--img_size', type=int, default=224,
-                        help='input image size.')    
-    parser.add_argument('--img_dim', type=int, default=3,
-                        help='3 for RGB; 1 for Gray.')    
-    parser.add_argument('--num_classes', type=int, default=1000,
-                        help='Number of the classes.')    
-    # Basic
-    parser.add_argument('--seed', type=int, default=42,
-                        help='random seed.')
-    parser.add_argument('--cuda', action='store_true', default=False,
-                        help='use cuda')
-    parser.add_argument('--batch_size', type=int, default=256,
-                        help='batch size on all GPUs')
-    parser.add_argument('--num_workers', type=int, default=4,
-                        help='number of workers')
-    parser.add_argument('--path_to_save', type=str, default='weights/',
-                        help='path to save trained model.')
-    parser.add_argument('--tfboard', action='store_true', default=False,
-                        help='use tensorboard')
-    parser.add_argument('--eval', action='store_true', default=False,
-                        help='evaluate model.')
-    # Epoch
-    parser.add_argument('--wp_epoch', type=int, default=20, 
-                        help='number of warmup epochs')
-    parser.add_argument('--start_epoch', type=int, default=0, 
-                        help='start epoch (useful when resuming training)')
-    parser.add_argument('--max_epoch', type=int, default=300, 
-                        help='max epoch')
-    parser.add_argument('--eval_epoch', type=int, default=10, 
-                        help='evaluation interval (in epochs)')
-    # Dataset
-    parser.add_argument('--dataset', type=str, default='cifar10',
-                        help='dataset name')
-    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
-                        help='path to dataset folder')
-    # Model
-    parser.add_argument('-m', '--model', type=str, default='rtcnet_n',
-                        help='model name')
-    parser.add_argument('--resume', default=None, type=str,
-                        help='keep training')
-    parser.add_argument('--ema', action='store_true', default=False,
-                        help='use ema.')
-    parser.add_argument('--drop_path', type=float, default=0.1,
-                        help='drop_path')
-    # Optimizer
-    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
-                        help='adamw, sgd')
-    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='step',
-                        help='cosine, step')
-    parser.add_argument('-mt', '--momentum', type=float, default=0.9,
-                        help='momentum for SGD')
-    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
-                        help='weight decay')
-    parser.add_argument('--batch_base', type=int, default=256,
-                        help='base batch size used to linearly scale the learning rate')
-    parser.add_argument('--base_lr', type=float, default=1e-3,
-                        help='learning rate for training model')
-    parser.add_argument('--min_lr', type=float, default=1e-6,
-                        help='the final lr')
-    parser.add_argument('--grad_accumulate', type=int, default=1,
-                        help='gradient accumulation')
-    parser.add_argument('--max_grad_norm', type=float, default=None,
-                        help='Clip gradient norm (default: None, no clipping)')
-    # Augmentation parameters
-    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
-                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
-    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
-                        help='Use AutoAugment policy, e.g. "v0", "original", or "rand-m9-mstd0.5-inc1"')
-    parser.add_argument('--smoothing', type=float, default=0.1,
-                        help='Label smoothing (default: 0.1)')
-    # Random Erase params
-    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
-                        help='Random erase prob (default: 0.25)')
-    parser.add_argument('--remode', type=str, default='pixel',
-                        help='Random erase mode (default: "pixel")')
-    parser.add_argument('--recount', type=int, default=1,
-                        help='Random erase count (default: 1)')
-    parser.add_argument('--resplit', action='store_true', default=False,
-                        help='Do not random erase first (clean) augmentation split')
-    # Mixup params
-    parser.add_argument('--mixup', type=float, default=0,
-                        help='mixup alpha, mixup enabled if > 0.')
-    parser.add_argument('--cutmix', type=float, default=0,
-                        help='cutmix alpha, cutmix enabled if > 0.')
-    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
-                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
-    parser.add_argument('--mixup_prob', type=float, default=1.0,
-                        help='Probability of performing mixup or cutmix when either/both is enabled')
-    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
-                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
-    parser.add_argument('--mixup_mode', type=str, default='batch',
-                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
-    # DDP
-    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
-                        help='distributed training')
-    parser.add_argument('--dist_url', default='env://', 
-                        help='url used to set up distributed training')
-    parser.add_argument('--world_size', default=1, type=int,
-                        help='number of distributed processes')
-    parser.add_argument('--sybn', action='store_true', default=False, 
-                        help='use sybn.')
-    parser.add_argument('--local_rank', default=-1, type=int,
-                        help='the number of local rank.')
-
-    return parser.parse_args()
-
-    
-def main():
-    args = parse_args()
-    # set random seed
-    setup_seed(args.seed)
-
-    # Path to save model
-    path_to_save = os.path.join(args.path_to_save, args.dataset, args.model)
-    os.makedirs(path_to_save, exist_ok=True)
-    args.output_dir = path_to_save
-    
-    # ------------------------- Build DDP environment -------------------------
-    ## LOCAL_RANK is the global process rank, with values in [0, world_size - 1].
-    ## LOCAL_PROCESS_RANK is the per-machine (local) rank, not the global one.
-    local_rank = local_process_rank = -1
-    if args.distributed:
-        distributed_utils.init_distributed_mode(args)
-        print("git:\n  {}\n".format(distributed_utils.get_sha()))
-        try:
-            # Multiple machines & multiple GPUs (world size > 8)
-            local_rank = torch.distributed.get_rank()
-            local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
-        except:
-            # Single machine & multiple GPUs (world size <= 8)
-            local_rank = local_process_rank = torch.distributed.get_rank()
-    print_rank_0(args)
-    args.world_size = distributed_utils.get_world_size()
-    print('World size: {}'.format(distributed_utils.get_world_size()))
-    print("LOCAL RANK: ", local_rank)
-    print("LOCAL_PROCESS_RANK: ", local_process_rank)
-
-    # ------------------------- Build CUDA -------------------------
-    if args.cuda:
-        if torch.cuda.is_available():
-            cudnn.benchmark = True
-            device = torch.device("cuda")
-        else:
-            print('There is no available GPU.')
-            args.cuda = False
-            device = torch.device("cpu")
-    else:
-        device = torch.device("cpu")
-
-    # ------------------------- Build Tensorboard -------------------------
-    tblogger = None
-    if local_rank <= 0 and args.tfboard:
-        print('use tensorboard')
-        from torch.utils.tensorboard import SummaryWriter
-        time_stamp = time.strftime('%Y-%m-%d_%H:%M:%S',time.localtime(time.time()))
-        log_path = os.path.join('log/', args.dataset, time_stamp)
-        os.makedirs(log_path, exist_ok=True)
-        tblogger = SummaryWriter(log_path)
-
-    # ------------------------- Build Dataset -------------------------
-    train_dataset = build_dataset(args, is_train=True)
-    val_dataset   = build_dataset(args, is_train=False)
-
-    # ------------------------- Build Dataloader -------------------------
-    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
-    val_dataloader   = build_dataloader(args, val_dataset,   is_train=False)
-
-    print('=================== Dataset Information ===================')
-    print("Dataset: ", args.dataset)
-    print('- train dataset size : ', len(train_dataset))
-    print('- val dataset size   : ', len(val_dataset))
-
-    # ------------------------- Build Model -------------------------
-    model = build_model(args)
-    model.train().to(device)
-    print(model)
-    if local_rank <= 0:
-        model_copy = deepcopy(model)
-        model_copy.eval()
-        FLOPs_and_Params(model_copy, args.img_size, args.img_dim, device)
-        model_copy.train()
-        del model_copy
-    if args.distributed:
-        # wait for all processes to synchronize
-        dist.barrier()
-
-    # ------------------------- Build DDP Model -------------------------
-    model_without_ddp = model
-    if args.distributed:
-        model = DDP(model, device_ids=[args.gpu])
-        if args.sybn:
-            print('use SyncBatchNorm ...')
-            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-        model_without_ddp = model.module
-
-    # ------------------------- Mixup augmentation config -------------------------
-    mixup_fn = None
-    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
-    if mixup_active:
-        print_rank_0("Mixup is activated!", local_rank)
-        mixup_fn = Mixup(mixup_alpha     = args.mixup,
-                         cutmix_alpha    = args.cutmix,
-                         cutmix_minmax   = args.cutmix_minmax,
-                         prob            = args.mixup_prob,
-                         switch_prob     = args.mixup_switch_prob,
-                         mode            = args.mixup_mode,
-                         label_smoothing = args.smoothing,
-                         num_classes     = args.num_classes)
-
-
-    # ------------------------- Build Optimzier -------------------------
-    optimizer = build_optimizer(args, model_without_ddp)
-    loss_scaler = NativeScaler()
-
-    # ------------------------- Build Lr Scheduler -------------------------
-    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
-    lr_scheduler = build_lr_scheduler(args, optimizer)
-
-    # ------------------------- Build Criterion -------------------------
-    if mixup_fn is not None:
-        # smoothing is handled with mixup label transform
-        criterion = SoftTargetCrossEntropy()
-    elif args.smoothing > 0.:
-        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
-    else:
-        criterion = torch.nn.CrossEntropyLoss()
-    load_model(args=args, model_without_ddp=model_without_ddp,
-               optimizer=optimizer, lr_scheduler=lr_scheduler, loss_scaler=loss_scaler)
-
-    # ------------------------- Build Model-EMA -------------------------
-    if args.ema:
-        print("Build model ema for {}".format(args.model))
-        updates = args.start_epoch * len(train_dataloader) // args.grad_accumulate
-        print("Initial updates of ModelEMA: {}".format(updates))
-        model_ema = ModelEMA(model_without_ddp, ema_decay=0.999, ema_tau=2000., updates=updates)
-    else:
-        model_ema = None
-
-    # ------------------------- Eval before Train Pipeline -------------------------
-    if args.eval:
-        print('evaluating ...')
-        test_stats = evaluate(val_dataloader, model_without_ddp, device, local_rank)
-        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5 : %.2f]' %
-                (test_stats['loss'], test_stats['acc1'], test_stats['acc5']), flush=True)
-        return
-
-    # ------------------------- Training Pipeline -------------------------
-    start_time = time.time()
-    max_accuracy = -1.0
-    print_rank_0("=============== Start training for {} epochs ===============".format(args.max_epoch), local_rank)
-    for epoch in range(args.start_epoch, args.max_epoch):
-        if args.distributed:
-            train_dataloader.batch_sampler.sampler.set_epoch(epoch)
-
-        # train one epoch
-        train_one_epoch(args, device, model, model_ema, train_dataloader, optimizer, epoch,
-                        lr_scheduler_warmup, loss_scaler, criterion, local_rank, tblogger, mixup_fn)
-
-        # LR scheduler
-        if (epoch + 1) > args.wp_epoch:
-            lr_scheduler.step()
-
-        # Evaluate
-        if local_rank <= 0:
-            model_eval = model_ema.ema if model_ema is not None else model_without_ddp
-            if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
-                print_rank_0("Evaluating ...")
-                test_stats = evaluate(val_dataloader, model_eval, device, local_rank)
-                print_rank_0(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%", local_rank)
-                max_accuracy = max(max_accuracy, test_stats["acc1"])
-                print_rank_0(f'Max accuracy: {max_accuracy:.2f}%', local_rank)
-
-                # Save model
-                print('- saving the model after {} epochs ...'.format(epoch))
-                save_model(args=args, model=model_eval, model_without_ddp=model_eval,
-                           optimizer=optimizer, lr_scheduler=lr_scheduler, loss_scaler=loss_scaler, epoch=epoch, acc1=max_accuracy)
-        if args.distributed:
-            dist.barrier()
-
-        if tblogger is not None:
-            tblogger.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
-            tblogger.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
-            tblogger.add_scalar('perf/test_loss', test_stats['loss'], epoch)
-        if args.distributed:
-            dist.barrier()
-
-    total_time = time.time() - start_time
-    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-    print('Training time {}'.format(total_time_str))
-
-
-if __name__ == "__main__":
-    main()

+ 0 - 164
iclab/train.sh

@@ -1,164 +0,0 @@
-# ------------------- Args setting -------------------
-MODEL=$1
-DATASET=$2
-DATASET_ROOT=$3
-WORLD_SIZE=$4
-MASTER_PORT=$5
-RESUME=$6
-
-# ------------------- Training setting -------------------
-## Epoch
-BATCH_SIZE=128
-GRAD_ACCUMULATE=32
-WP_EPOCH=10
-MAX_EPOCH=100
-EVAL_EPOCH=5
-DROP_PATH=0.1
-## Scheduler
-OPTIMIZER="adamw"
-LRSCHEDULER="cosine"
-BASE_LR=1e-3         # 0.1 for SGD; 0.001 for AdamW
-MIN_LR=1e-6
-BATCH_BASE=1024      # 256 for SGD; 1024 for AdamW
-MOMENTUM=0.9
-WEIGHT_DECAY=0.05    # 0.0001 for SGD; 0.05 for AdamW
-
-# ------------------- Dataset config -------------------
-if [[ $DATASET == "mnist" ]]; then
-    IMG_SIZE=28
-    NUM_CLASSES=10
-elif [[ $DATASET == "cifar10" ]]; then
-    IMG_SIZE=32
-    NUM_CLASSES=10
-elif [[ $DATASET == "cifar100" ]]; then
-    IMG_SIZE=32
-    NUM_CLASSES=100
-elif [[ $DATASET == "imagenet_1k" || $DATASET == "imagenet_22k" ]]; then
-    IMG_SIZE=224
-    NUM_CLASSES=1000
-elif [[ $DATASET == "custom" ]]; then
-    IMG_SIZE=224
-    NUM_CLASSES=2
-else
-    echo "Unknown dataset!!"
-    exit 1
-fi
-
-
-# ------------------- Training pipeline -------------------
-if [ $WORLD_SIZE == 1 ]; then
-    python train.py \
-            --cuda \
-            --root ${DATASET_ROOT} \
-            --dataset ${DATASET} \
-            --model ${MODEL} \
-            --resume ${RESUME} \
-            --batch_size ${BATCH_SIZE} \
-            --batch_base ${BATCH_BASE} \
-            --grad_accumulate ${GRAD_ACCUMULATE} \
-            --img_size ${IMG_SIZE} \
-            --drop_path ${DROP_PATH} \
-            --max_epoch ${MAX_EPOCH} \
-            --wp_epoch ${WP_EPOCH} \
-            --eval_epoch ${EVAL_EPOCH} \
-            --optimizer ${OPTIMIZER} \
-            --lr_scheduler ${LRSCHEDULER} \
-            --base_lr ${BASE_LR} \
-            --min_lr ${MIN_LR} \
-            --momentum ${MOMENTUM} \
-            --weight_decay ${WEIGHT_DECAY} \
-            --color_jitter 0.0 \
-            --reprob 0.0 \
-            --mixup 0.0 \
-            --cutmix 0.0
-
-elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
-    python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
-            --cuda \
-            --distributed \
-            --root ${DATASET_ROOT} \
-            --dataset ${DATASET} \
-            --model ${MODEL} \
-            --resume ${RESUME} \
-            --batch_size ${BATCH_SIZE} \
-            --batch_base ${BATCH_BASE} \
-            --grad_accumulate ${GRAD_ACCUMULATE} \
-            --img_size ${IMG_SIZE} \
-            --drop_path ${DROP_PATH} \
-            --max_epoch ${MAX_EPOCH} \
-            --wp_epoch ${WP_EPOCH} \
-            --eval_epoch ${EVAL_EPOCH} \
-            --optimizer ${OPTIMIZER} \
-            --lr_scheduler ${LRSCHEDULER} \
-            --base_lr ${BASE_LR} \
-            --min_lr ${MIN_LR} \
-            --momentum ${MOMENTUM} \
-            --weight_decay ${WEIGHT_DECAY} \
-            --sybn \
-            --color_jitter 0.0 \
-            --reprob 0.0 \
-            --mixup 0.0 \
-            --cutmix 0.0
-else
-    echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
-          multi-card training mode, which is currently unsupported."
-    exit 1
-fi
-
-
-# # ------------------- Training pipeline with strong augmentations -------------------
-# if [ $WORLD_SIZE == 1 ]; then
-#     python train.py \
-#             --cuda \
-#             --root ${DATASET_ROOT} \
-#             --dataset ${DATASET} \
-#             --model ${MODEL} \
-#             --resume ${RESUME} \
-#             --batch_size ${BATCH_SIZE} \
-#             --batch_base ${BATCH_BASE} \
-#             --grad_accumulate ${GRAD_ACCUMULATE} \
-#             --img_size ${IMG_SIZE} \
-#             --drop_path ${DROP_PATH} \
-#             --max_epoch ${MAX_EPOCH} \
-#             --wp_epoch ${WP_EPOCH} \
-#             --eval_epoch ${EVAL_EPOCH} \
-#             --optimizer ${OPTIMIZER} \
-#             --lr_scheduler ${LRSCHEDULER} \
-#             --base_lr ${BASE_LR} \
-#             --min_lr ${MIN_LR} \
-#             --weight_decay ${WEIGHT_DECAY} \
-#             --aa "rand-m9-mstd0.5-inc1" \
-#             --reprob 0.25 \
-#             --mixup 0.8 \
-#             --cutmix 1.0
-# elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
-#     python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
-#             --cuda \
-#             --distributed \
-#             --root ${DATASET_ROOT} \
-#             --dataset ${DATASET} \
-#             --model ${MODEL} \
-#             --resume ${RESUME} \
-#             --batch_size ${BATCH_SIZE} \
-#             --batch_base ${BATCH_BASE} \
-#             --grad_accumulate ${GRAD_ACCUMULATE} \
-#             --img_size ${IMG_SIZE} \
-#             --drop_path ${DROP_PATH} \
-#             --max_epoch ${MAX_EPOCH} \
-#             --wp_epoch ${WP_EPOCH} \
-#             --eval_epoch ${EVAL_EPOCH} \
-#             --optimizer ${OPTIMIZER} \
-#             --lr_scheduler ${LRSCHEDULER} \
-#             --base_lr ${BASE_LR} \
-#             --min_lr ${MIN_LR} \
-#             --weight_decay ${WEIGHT_DECAY} \
-#             --sybn \
-#             --aa "rand-m9-mstd0.5-inc1" \
-#             --reprob 0.25 \
-#             --mixup 0.8 \
-#             --cutmix 1.0
-# else
-#     echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
-#           multi-card training mode, which is currently unsupported."
-#     exit 1
-# fi

+ 0 - 18
iclab/utils/com_flops_params.py

@@ -1,18 +0,0 @@
-import torch
-from thop import profile
-
-
-def FLOPs_and_Params(model, size, img_dim, device):
-    x = torch.randn(1, img_dim, size, size).to(device)
-    model.eval()
-
-    flops, params = profile(model, inputs=(x, ))
-    print('=================== FLOPs & Params ===================')
-    print('- GFLOPs : ', flops / 1e9 * 2)
-    print('- Params : ', params / 1e6, ' M')
-
-    model.train()
-
-
-if __name__ == "__main__":
-    pass

+ 0 - 165
iclab/utils/distributed_utils.py

@@ -1,165 +0,0 @@
-# from github: https://github.com/ruinmessi/ASFF/blob/master/utils/distributed_util.py
-
-import torch
-import torch.distributed as dist
-import os
-import subprocess
-import pickle
-
-
-def all_gather(data):
-    """
-    Run all_gather on arbitrary picklable data (not necessarily tensors)
-    Args:
-        data: any picklable object
-    Returns:
-        list[data]: list of data gathered from each rank
-    """
-    world_size = get_world_size()
-    if world_size == 1:
-        return [data]
-
-    # serialized to a Tensor
-    buffer = pickle.dumps(data)
-    storage = torch.ByteStorage.from_buffer(buffer)
-    tensor = torch.ByteTensor(storage).to("cuda")
-
-    # obtain Tensor size of each rank
-    local_size = torch.tensor([tensor.numel()], device="cuda")
-    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
-    dist.all_gather(size_list, local_size)
-    size_list = [int(size.item()) for size in size_list]
-    max_size = max(size_list)
-
-    # receiving Tensor from all ranks
-    # we pad the tensor because torch all_gather does not support
-    # gathering tensors of different shapes
-    tensor_list = []
-    for _ in size_list:
-        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
-    if local_size != max_size:
-        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
-        tensor = torch.cat((tensor, padding), dim=0)
-    dist.all_gather(tensor_list, tensor)
-
-    data_list = []
-    for size, tensor in zip(size_list, tensor_list):
-        buffer = tensor.cpu().numpy().tobytes()[:size]
-        data_list.append(pickle.loads(buffer))
-
-    return data_list
-
-
-def reduce_dict(input_dict, average=True):
-    """
-    Args:
-        input_dict (dict): all the values will be reduced
-        average (bool): whether to do average or sum
-    Reduce the values in the dictionary from all processes so that all processes
-    have the averaged results. Returns a dict with the same fields as
-    input_dict, after reduction.
-    """
-    world_size = get_world_size()
-    if world_size < 2:
-        return input_dict
-    with torch.no_grad():
-        names = []
-        values = []
-        # sort the keys so that they are consistent across processes
-        for k in sorted(input_dict.keys()):
-            names.append(k)
-            values.append(input_dict[k])
-        values = torch.stack(values, dim=0)
-        dist.all_reduce(values)
-        if average:
-            values /= world_size
-        reduced_dict = {k: v for k, v in zip(names, values)}
-    return reduced_dict
-
-
-def get_sha():
-    cwd = os.path.dirname(os.path.abspath(__file__))
-
-    def _run(command):
-        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
-    sha = 'N/A'
-    diff = "clean"
-    branch = 'N/A'
-    try:
-        sha = _run(['git', 'rev-parse', 'HEAD'])
-        subprocess.check_output(['git', 'diff'], cwd=cwd)
-        diff = _run(['git', 'diff-index', 'HEAD'])
-        diff = "has uncommitted changes" if diff else "clean"
-        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
-    except Exception:
-        pass
-    message = f"sha: {sha}, status: {diff}, branch: {branch}"
-    return message
-
-
-def setup_for_distributed(is_master):
-    """
-    This function disables printing when not in master process
-    """
-    import builtins as __builtin__
-    builtin_print = __builtin__.print
-
-    def print(*args, **kwargs):
-        force = kwargs.pop('force', False)
-        if is_master or force:
-            builtin_print(*args, **kwargs)
-
-    __builtin__.print = print
-
-
-def is_dist_avail_and_initialized():
-    if not dist.is_available():
-        return False
-    if not dist.is_initialized():
-        return False
-    return True
-
-
-def get_world_size():
-    if not is_dist_avail_and_initialized():
-        return 1
-    return dist.get_world_size()
-
-
-def get_rank():
-    if not is_dist_avail_and_initialized():
-        return 0
-    return dist.get_rank()
-
-def is_main_process():
-    return get_rank() == 0
-
-
-def save_on_master(*args, **kwargs):
-    if is_main_process():
-        torch.save(*args, **kwargs)
-
-
-def init_distributed_mode(args):
-    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
-        args.rank = int(os.environ["RANK"])
-        args.world_size = int(os.environ['WORLD_SIZE'])
-        args.gpu = int(os.environ['LOCAL_RANK'])
-    elif 'SLURM_PROCID' in os.environ:
-        args.rank = int(os.environ['SLURM_PROCID'])
-        args.gpu = args.rank % torch.cuda.device_count()
-    else:
-        print('Not using distributed mode')
-        args.distributed = False
-        return
-
-    args.distributed = True
-
-    torch.cuda.set_device(args.gpu)
-    args.dist_backend = 'nccl'
-    print('| distributed init (rank {}): {}'.format(
-        args.rank, args.dist_url), flush=True)
-    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
-                                         world_size=args.world_size, rank=args.rank)
-    torch.distributed.barrier()
-    setup_for_distributed(args.rank == 0)

+ 0 - 52
iclab/utils/ema.py

@@ -1,52 +0,0 @@
-from copy import deepcopy
-import math
-import torch
-import torch.nn as nn
-
-
-# ---------------------------- Model tools ----------------------------
-def is_parallel(model):
-    # Returns True if model is of type DP or DDP
-    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
-
-## Model EMA
-class ModelEMA(object):
-    def __init__(self, model, ema_decay=0.9999, ema_tau=2000, updates=0):
-        # Create EMA
-        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
-        self.updates = updates  # number of EMA updates
-        self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau))  # decay exponential ramp (to help early epochs)
-        for p in self.ema.parameters():
-            p.requires_grad_(False)
-
-    def is_parallel(self, model):
-        # Returns True if model is of type DP or DDP
-        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
-
-    def de_parallel(self, model):
-        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
-        return model.module if self.is_parallel(model) else model
-
-    def copy_attr(self, a, b, include=(), exclude=()):
-        # Copy attributes from b to a, options to only include [...] and to exclude [...]
-        for k, v in b.__dict__.items():
-            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
-                continue
-            else:
-                setattr(a, k, v)
-
-    def update(self, model):
-        # Update EMA parameters
-        self.updates += 1
-        d = self.decay(self.updates)
-
-        msd = self.de_parallel(model).state_dict()  # model state_dict
-        for k, v in self.ema.state_dict().items():
-            if v.dtype.is_floating_point:  # true for FP16 and FP32
-                v *= d
-                v += (1 - d) * msd[k].detach()
-        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
-
-    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
-        # Update EMA attributes
-        self.copy_attr(self.ema, model, include, exclude)
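
The deleted `ModelEMA` ramps its decay with `ema_decay * (1 - exp(-updates / ema_tau))`, so the EMA tracks the raw weights closely early in training and only later approaches the full decay. A quick sketch of that ramp using the class defaults (`ema_decay=0.9999`, `ema_tau=2000`; printed values are approximate):

```python
import math

ema_decay, ema_tau = 0.9999, 2000.0
decay = lambda updates: ema_decay * (1.0 - math.exp(-updates / ema_tau))

for updates in (1, 100, 2000, 10000):
    print(updates, f"{decay(updates):.4f}")
# roughly: 1 -> 0.0005, 100 -> 0.0488, 2000 -> 0.6321, 10000 -> 0.9932
```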

+ 0 - 24
iclab/utils/optimzer.py

@@ -1,24 +0,0 @@
-import torch
-
-
-def build_optimizer(args, model):
-    ## learning rate
-    if args.optimizer == "adamw":
-        args.base_lr = args.base_lr / args.batch_base * args.batch_size * args.grad_accumulate    # auto scale lr
-        ## optimizer
-        optimizer = torch.optim.AdamW(model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
-    elif args.optimizer == "sgd":
-        args.base_lr = args.base_lr / args.batch_base * args.batch_size * args.grad_accumulate    # auto scale lr
-        ## optimizer
-        optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
-    else:
-        raise NotImplementedError("Unknown optimizer: {}".format(args.optimizer))
-
-    print("=================== Optimizer information ===================")
-    print("Optimizer: ", args.optimizer)
-    print("- momentum: ", args.momentum)
-    print("- weight decay: ", args.weight_decay)
-    print('- base lr: ', args.base_lr)
-    print('- min  lr: ', args.min_lr)
-
-    return optimizer
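
The deleted `build_optimizer` applies the linear learning-rate scaling rule: the configured `base_lr` is interpreted relative to `batch_base` and rescaled by the effective batch size (`batch_size * grad_accumulate`). A small sketch of that arithmetic with the AdamW defaults from `train.sh` above (the concrete numbers are only an illustration):

```python
base_lr, batch_base = 1e-3, 1024        # AdamW defaults in train.sh
batch_size, grad_accumulate = 128, 32   # per-step batch and accumulation steps

effective_batch = batch_size * grad_accumulate       # 4096 samples per optimizer update
scaled_lr = base_lr / batch_base * effective_batch   # 1e-3 * 4096 / 1024
print(scaled_lr)  # 0.004
```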

+ 2 - 1
iclab/.gitignore → image_classification/.gitignore

@@ -6,6 +6,7 @@
 *.zip
 weights
 __pycache__
-.vscode
+data/cifar/
 data/cifar_data/
 data/mnist_data/
+vis_results/

+ 27 - 0
image_classification/README.md

@@ -0,0 +1,27 @@
+# General Image Classification Laboratory
+
+## Train
+For example, to train the `ConvNet` model defined in this repo, run the following command:
+
+```Shell
+cd Vision-Pretraining-Tutorial/image_classification/
+python main.py --cuda \
+               --dataset cifar10 \
+               --model convnet \
+               --batch_size 256 \
+               --optimizer adamw \
+               --base_lr 1e-3 \
+               --min_lr 1e-6
+```
+
+## Evaluate
+- Evaluate the top-1 & top-5 accuracy:
+```Shell
+cd Vision-Pretraining-Tutorial/image_classification/
+python main.py --cuda \
+               --dataset cifar10 \
+               --model convnet \
+               --batch_size 256 \
+               --eval \
+               --resume path/to/checkpoint
+```

+ 42 - 0
image_classification/data/__init__.py

@@ -0,0 +1,42 @@
+import torch
+
+from .cifar import CifarDataset
+from .mnist import MnistDataset
+from .custom import CustomDataset
+
+
+def build_dataset(args, is_train=False):
+    if args.dataset == 'cifar10':
+        args.img_dim     = 3
+        args.img_size    = 32
+        args.mlp_in_dim  = 32 * 32 * 3
+        args.num_classes = 10
+        args.patch_size  = 4
+        return CifarDataset(is_train)
+    elif args.dataset == 'mnist':
+        args.img_dim     = 1
+        args.img_size    = 28
+        args.mlp_in_dim  = 28 * 28 * 1
+        args.num_classes = 10
+        args.patch_size  = 4
+        return MnistDataset(is_train)
+    elif args.dataset == 'custom':
+        assert args.num_classes is not None and isinstance(args.num_classes, int)
+        args.img_size = 224
+        args.mlp_in_dim = 224 * 224 * 3
+        args.patch_size  = 16
+        return CustomDataset(args, is_train)
+    
+
+def build_dataloader(args, dataset, is_train=False):
+    if is_train:
+        sampler = torch.utils.data.RandomSampler(dataset)
+        batch_sampler_train = torch.utils.data.BatchSampler(
+            sampler, args.batch_size, drop_last=True if is_train else False)
+        dataloader = torch.utils.data.DataLoader(
+            dataset, batch_sampler=batch_sampler_train, num_workers=args.num_workers, pin_memory=True)
+    else:
+        dataloader = torch.utils.data.DataLoader(
+            dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
+
+    return dataloader

+ 2 - 3
iclab/data/cifar.py → image_classification/data/cifar.py

@@ -6,7 +6,7 @@ from torchvision.datasets import CIFAR10
 
 
 class CifarDataset(data.Dataset):
-    def __init__(self, is_train=False, transform=None):
+    def __init__(self, is_train=False):
         super().__init__()
         # ----------------- basic parameters -----------------
         self.is_train   = is_train
@@ -46,7 +46,7 @@ class CifarDataset(data.Dataset):
 
     def build_transform(self):
         if self.is_train:
-            transforms = T.Compose([T.ToTensor(), T.RandomCrop(size=32, padding=8)])
+            transforms = T.Compose([T.ToTensor(), T.RandomCrop(size=32, padding=4)])
         else:
             transforms = T.Compose([T.ToTensor()])
 
@@ -75,4 +75,3 @@ if __name__ == "__main__":
 
         cv2.imshow('image', image)
         cv2.waitKey(0)
-

+ 17 - 39
iclab/data/custom.py → image_classification/data/custom.py

@@ -1,15 +1,13 @@
 import os
 import PIL
 import numpy as np
-from timm.data import create_transform
 import torch.utils.data as data
 import torchvision.transforms as T
 from torchvision.datasets import ImageFolder
-import torchvision.transforms as transforms
 
 
 class CustomDataset(data.Dataset):
-    def __init__(self, args, is_train=False, transform=None):
+    def __init__(self, args, is_train=False):
         super().__init__()
         # ----------------- basic parameters -----------------
         self.args = args
@@ -21,7 +19,7 @@ class CustomDataset(data.Dataset):
         self.image_set = 'train' if is_train else 'val'
         self.data_path = os.path.join(args.root, self.image_set)
         # ----------------- dataset & transforms -----------------
-        self.transform = transform if transform is not None else self.build_transform(args)
+        self.transform = self.build_transform()
         self.dataset = ImageFolder(root=self.data_path, transform=self.transform)
 
     def __len__(self):
@@ -44,40 +42,27 @@ class CustomDataset(data.Dataset):
 
         return image, target
 
-    def build_transform(self, args):
+    def build_transform(self):
         if self.is_train:
-            transforms = create_transform(input_size    = args.img_size,
-                                          is_training   = True,
-                                          color_jitter  = args.color_jitter,
-                                          auto_augment  = args.aa,
-                                          interpolation = 'bicubic',
-                                          re_prob       = args.reprob,
-                                          re_mode       = args.remode,
-                                          re_count      = args.recount,
-                                          mean          = self.pixel_mean,
-                                          std           = self.pixel_std,
-                                          )
+            transforms = T.Compose([
+                            T.RandomResizedCrop(224),
+                            T.RandomHorizontalFlip(0.5),
+                            T.ToTensor(),
+                            T.Normalize(self.pixel_mean,
+                                        self.pixel_std)])
         else:
-            t = []
-            if args.img_size <= 224:
-                crop_pct = 224 / 256
-            else:
-                crop_pct = 1.0
-            size = int(args.img_size / crop_pct)
-            t.append(
-                T.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
-            )
-            t.append(T.CenterCrop(args.img_size))
-            t.append(T.ToTensor())
-            t.append(T.Normalize(self.pixel_mean, self.pixel_std))
-            transforms = T.Compose(t)
+            transforms = T.Compose([
+                T.Resize(256, interpolation=PIL.Image.BICUBIC),
+                T.CenterCrop(224),
+                T.ToTensor(),
+                T.Normalize(self.pixel_mean, self.pixel_std),
+            ])
 
         return transforms
 
 
 if __name__ == "__main__":
     import cv2
-    import torch
     import argparse
     
     parser = argparse.ArgumentParser(description='Custom-Dataset')
@@ -88,19 +73,12 @@ if __name__ == "__main__":
     parser.add_argument('--img_size', default=224, type=int,
                         help='input image size.')
     args = parser.parse_args()
-
-    # Transforms
-    train_transform = transforms.Compose([
-            transforms.RandomResizedCrop(args.img_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
   
     # Dataset
-    dataset = CustomDataset(args, is_train=True, transform=train_transform)  
+    dataset = CustomDataset(args, is_train=True)  
     print('Dataset size: ', len(dataset))
 
-    for i in range(1000):
+    for i in range(len(dataset)):
         image, target = dataset.pull_image(i)
         # to BGR
         image = image[..., (2, 1, 0)]

+ 2 - 2
iclab/data/mnist.py → image_classification/data/mnist.py

@@ -5,7 +5,7 @@ from torchvision.datasets import MNIST
 
 
 class MnistDataset(data.Dataset):
-    def __init__(self, is_train=False, transform=None):
+    def __init__(self, is_train=False):
         super().__init__()
         # ----------------- basic parameters -----------------
         self.is_train   = is_train
@@ -48,7 +48,7 @@ class MnistDataset(data.Dataset):
 
 if __name__ == "__main__":
     import cv2
-
+    
     # dataset
     dataset = MnistDataset(is_train=True)  
     print('Dataset size: ', len(dataset))

+ 24 - 48
iclab/engine.py → image_classification/engine.py

@@ -1,26 +1,20 @@
 import sys
 import math
-import numpy as np
-
 import torch
 
 from utils.misc import MetricLogger, SmoothedValue
-from utils.misc import print_rank_0, all_reduce_mean, accuracy
+from utils.misc import accuracy
 
 
 def train_one_epoch(args,
                     device,
                     model,
-                    model_ema,
                     data_loader,
                     optimizer,
                     epoch,
                     lr_scheduler_warmup,
-                    loss_scaler,
                     criterion,
-                    local_rank=0,
-                    tblogger=None,
-                    mixup_fn=None):
+                    ):
     model.train(True)
     metric_logger = MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
@@ -45,14 +39,11 @@ def train_one_epoch(args,
         images = images.to(device, non_blocking=True)
         targets = targets.to(device, non_blocking=True)
 
-        # Mixup
-        if mixup_fn is not None:
-            images, targets = mixup_fn(images, targets)
-
         # Inference
-        with torch.cuda.amp.autocast():
-            output = model(images)
-            loss = criterion(output, targets)
+        output = model(images)
+
+        # Compute loss
+        loss = criterion(output, targets)
 
         # Check loss
         loss_value = loss.item()
@@ -60,61 +51,47 @@ def train_one_epoch(args,
             print("Loss is {}, stopping training".format(loss_value))
             sys.exit(1)
 
-        # Backward & Optimize
-        loss /= args.grad_accumulate
-        loss_scaler(loss, optimizer, clip_grad=args.max_grad_norm, 
-                    parameters=model.parameters(), create_graph=False,
-                    update_grad=(iter_i + 1) % args.grad_accumulate == 0)
-        if (iter_i + 1) % args.grad_accumulate == 0:
-            optimizer.zero_grad()
-            if model_ema is not None:
-                model_ema.update(model)
+        # Backward
+        loss.backward()
 
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
+        # Optimize
+        optimizer.step()
+        optimizer.zero_grad()
 
         # Logs
         lr = optimizer.param_groups[0]["lr"]
         metric_logger.update(loss=loss_value)
         metric_logger.update(lr=lr)
 
-        loss_value_reduce = all_reduce_mean(loss_value)
-        if tblogger is not None and (iter_i + 1) % args.grad_accumulate == 0:
-            """ We use epoch_1000x as the x-axis in tensorboard.
-            This calibrates different curves when batch size changes.
-            """
-            epoch_1000x = int((iter_i / len(data_loader) + epoch) * 1000)
-            tblogger.add_scalar('loss', loss_value_reduce, epoch_1000x)
-            tblogger.add_scalar('lr', lr, epoch_1000x)
-
     # gather the stats from all processes
-    metric_logger.synchronize_between_processes()
-    print_rank_0("Averaged stats: {}".format(metric_logger), local_rank)
+    print("Averaged stats: {}".format(metric_logger))
 
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
 
 
 @torch.no_grad()
-def evaluate(data_loader, model, device, local_rank):
+def evaluate(data_loader, model, device):
     criterion = torch.nn.CrossEntropyLoss()
 
     metric_logger = MetricLogger(delimiter="  ")
     header = 'Test:'
 
-    # switch to evaluation mode
+    # Switch to evaluation mode
     model.eval()
 
     for batch in metric_logger.log_every(data_loader, 10, header):
         images = batch[0]
-        target = batch[-1]
+        target = batch[1]
         images = images.to(device, non_blocking=True)
         target = target.to(device, non_blocking=True)
 
-        # compute output
-        with torch.cuda.amp.autocast():
-            output = model(images)
-            loss = criterion(output, target)
+        # Inference
+        output = model(images)
+
+        # Compute loss
+        loss = criterion(output, target)
 
+        # Compute accuracy
         acc1, acc5 = accuracy(output, target, topk=(1, 5))
 
         batch_size = images.shape[0]
@@ -123,9 +100,8 @@ def evaluate(data_loader, model, device, local_rank):
         metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
 
     # gather the stats from all processes
-    metric_logger.synchronize_between_processes()
-    print_rank_0('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
-                 .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss),
-                 local_rank)
+    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)
+          )
 
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

+ 204 - 0
image_classification/main.py

@@ -0,0 +1,204 @@
+import os
+import time
+import matplotlib.pyplot as plt
+import argparse
+import datetime
+
+# ---------------- Torch components ----------------
+import torch
+import torch.backends.cudnn as cudnn
+
+# ---------------- Dataset components ----------------
+from data import build_dataset, build_dataloader
+
+# ---------------- Model components ----------------
+from models import build_model
+
+# ---------------- Utils components ----------------
+from utils.misc import setup_seed, load_model, save_model
+from utils.optimzer import build_optimizer
+from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
+
+# ---------------- Training engine ----------------
+from engine import train_one_epoch, evaluate
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    # Basic
+    parser.add_argument('--seed', type=int, default=42,
+                        help='random seed.')
+    parser.add_argument('--cuda', action='store_true', default=False,
+                        help='use cuda')
+    parser.add_argument('--batch_size', type=int, default=256,
+                        help='batch size on all GPUs')
+    parser.add_argument('--num_workers', type=int, default=4,
+                        help='number of workers')
+    parser.add_argument('--path_to_save', type=str, default='weights/',
+                        help='path to save trained model.')
+    parser.add_argument('--eval', action='store_true', default=False,
+                        help='evaluate model.')
+    # Epoch
+    parser.add_argument('--wp_epoch', type=int, default=1,
+                        help='number of warmup epochs')
+    parser.add_argument('--start_epoch', type=int, default=0,
+                        help='start epoch (useful when resuming from a checkpoint)')
+    parser.add_argument('--max_epoch', type=int, default=50, 
+                        help='max epoch')
+    parser.add_argument('--eval_epoch', type=int, default=5,
+                        help='evaluation interval (in epochs)')
+    # Dataset
+    parser.add_argument('--dataset', type=str, default='cifar10',
+                        help='dataset name')
+    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
+                        help='path to dataset folder')
+    parser.add_argument('--img_dim', type=int, default=3, 
+                        help='input image dimension')
+    parser.add_argument('--num_classes', type=int, default=1000, 
+                        help='number of the classes')
+    # Model
+    parser.add_argument('-m', '--model', type=str, default='mlp',
+                        help='model name')
+    parser.add_argument('--resume', default=None, type=str,
+                        help='keep training')
+    # Optimizer
+    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
+                        help='adamw, sgd')
+    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
+                        help='weight decay')
+    parser.add_argument('--base_lr', type=float, default=1e-3,
+                        help='learning rate for training model')
+    parser.add_argument('--min_lr', type=float, default=1e-6,
+                        help='the final lr')
+    # Lr scheduler
+    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='step',
+                        help='lr scheduler: cosine, step')
+
+    return parser.parse_args()
+
+    
+def main():
+    args = parse_args()
+    print(args)
+    # set random seed
+    setup_seed(args.seed)
+
+    # Path to save model
+    path_to_save = os.path.join(args.path_to_save, args.dataset, args.model)
+    os.makedirs(path_to_save, exist_ok=True)
+    args.output_dir = path_to_save
+    
+    # ------------------------- Build CUDA -------------------------
+    if args.cuda:
+        if torch.cuda.is_available():
+            cudnn.benchmark = True
+            device = torch.device("cuda")
+        else:
+            print('There is no available GPU.')
+            args.cuda = False
+            device = torch.device("cpu")
+    else:
+        device = torch.device("cpu")
+
+    # ------------------------- Build Dataset -------------------------
+    train_dataset = build_dataset(args, is_train=True)
+    val_dataset   = build_dataset(args, is_train=False)
+
+    # ------------------------- Build Dataloader -------------------------
+    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
+    val_dataloader   = build_dataloader(args, val_dataset,   is_train=False)
+
+    print('=================== Dataset Information ===================')
+    print("Dataset: ", args.dataset)
+    print('- train dataset size : ', len(train_dataset))
+    print('- val dataset size   : ', len(val_dataset))
+
+    # ------------------------- Build Model -------------------------
+    model = build_model(args)
+    model.train().to(device)
+    print(model)
+
+    # ------------------------- Build Criterion -------------------------
+    criterion = torch.nn.CrossEntropyLoss()
+
+    # ------------------------- Build Optimizer -------------------------
+    optimizer = build_optimizer(args, model)
+
+    # ------------------------- Build Lr Scheduler -------------------------
+    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
+    lr_scheduler = build_lr_scheduler(args, optimizer)
+
+    # ------------------------- Load Checkpoint -------------------------
+    load_model(args, model, optimizer, lr_scheduler)
+
+    # ------------------------- Eval before Train Pipeline -------------------------
+    if args.eval:
+        print('evaluating ...')
+        test_stats = evaluate(val_dataloader, model, device)
+        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5 : %.2f]' %
+                (test_stats['loss'], test_stats['acc1'], test_stats['acc5']), flush=True)
+        return
+
+    # ------------------------- Training Pipeline -------------------------
+    start_time = time.time()
+    max_accuracy = -1.0
+    print("=============== Start training for {} epochs ===============".format(args.max_epoch))
+    train_loss_logs = []
+    valid_loss_logs = []
+    valid_acc1_logs = []
+    for epoch in range(args.start_epoch, args.max_epoch):
+        # train one epoch
+        train_stats = train_one_epoch(args, device, model, train_dataloader, optimizer,
+                                      epoch, lr_scheduler_warmup, criterion)
+
+        # LR scheduler
+        if (epoch + 1) > args.wp_epoch:
+            lr_scheduler.step()
+
+        train_loss_logs.append((epoch, train_stats["loss"]))
+
+        # Evaluate
+        if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
+            print("Evaluating ...")
+            test_stats = evaluate(val_dataloader, model, device)
+            print(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%")
+            max_accuracy = max(max_accuracy, test_stats["acc1"])
+            print(f'Max accuracy: {max_accuracy:.2f}%')
+
+            # Save model
+            print('- saving the model after {} epochs ...'.format(epoch))
+            save_model(args, epoch, model, optimizer, lr_scheduler, test_stats["acc1"])
+
+            valid_acc1_logs.append((epoch, test_stats["acc1"]))
+            valid_loss_logs.append((epoch, test_stats["loss"]))
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+    # --------------- Plot log curve ---------------
+    ## Training loss
+    epochs = [sample[0] for sample in train_loss_logs]
+    tloss  = [sample[1] for sample in train_loss_logs]
+    plt.plot(epochs, tloss, c='r', label='training loss')
+    plt.xlabel('epoch')
+    plt.ylabel('loss')
+    plt.title('Training & Validation loss curve')
+    ## Valid loss
+    epochs = [sample[0] for sample in valid_loss_logs]
+    vloss  = [sample[1] for sample in valid_loss_logs]
+    plt.plot(epochs, vloss, c='b', label='validation loss')
+    plt.legend()
+    plt.show()
+    ## Valid acc1
+    epochs = [sample[0] for sample in valid_acc1_logs]
+    acc1   = [sample[1] for sample in valid_acc1_logs]
+    plt.plot(epochs, acc1, label='validation top-1 accuracy')
+    plt.xlabel('epoch')
+    plt.ylabel('top1 accuracy')
+    plt.title('Validation top-1 accuracy curve')
+    plt.show()
+
+
+
+if __name__ == "__main__":
+    main()

+ 20 - 0
image_classification/models/__init__.py

@@ -0,0 +1,20 @@
+from .mlp.build     import build_mlp
+from .convnet.build import build_convnet
+from .resnet.build  import build_resnet
+from .vit.build     import build_vit
+
+
+def build_model(args):
+    # --------------------------- ResNet series ---------------------------
+    if   'mlp' in args.model:
+        model = build_mlp(args)
+    elif 'convnet' in args.model:
+        model = build_convnet(args)
+    elif 'resnet' in args.model:
+        model = build_resnet(args)
+    elif 'vit' in args.model:
+        model = build_vit(args)
+    else:
+        raise NotImplementedError("Unknown model: {}".format(args.model))
+
+    return model

+ 16 - 0
image_classification/models/convnet/build.py

@@ -0,0 +1,16 @@
+from .convnet import ConvNet
+
+def build_convnet(args):
+    if args.model == "convnet":
+        model = ConvNet(img_size      = args.img_size,
+                        in_dim        = args.img_dim,
+                        hidden_dim    = 64,
+                        num_classes   = args.num_classes,
+                        act_type      = "relu",
+                        norm_type     = "bn",
+                        use_adavgpool = True)
+        
+    else:
+        raise NotImplementedError("Unknown model: {}".format(args.model))
+    
+    return model

+ 120 - 0
image_classification/models/convnet/convnet.py

@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .modules import ConvModule
+except:
+    from  modules import ConvModule
+
+
+# Convolutional Network
+class ConvNet(nn.Module):
+    def __init__(self,
+                 img_size      :int = 224,
+                 in_dim        :int = 3,
+                 hidden_dim    :int = 16,
+                 num_classes   :int = 10,
+                 act_type      :str = "relu",
+                 norm_type     :str = "bn",
+                 depthwise     :bool = False,
+                 use_adavgpool :bool = True,
+                 ) -> None:
+        super().__init__()
+        # ---------- Basic parameters ----------
+        self.img_size    = img_size
+        self.num_classes = num_classes
+        self.act_type    = act_type
+        self.norm_type   = norm_type
+        self.use_adavgpool = use_adavgpool
+        self.layer_dims    = [hidden_dim, hidden_dim*2, hidden_dim*4, hidden_dim*4]
+        # ---------- Model parameters ----------
+        self.layer_1 = nn.Sequential(
+            ConvModule(in_dim, hidden_dim,
+                       kernel_size=3, padding=1, stride=2,
+                       act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ConvModule(hidden_dim, hidden_dim,
+                       kernel_size=3, padding=1, stride=1,
+                       act_type=act_type, norm_type=norm_type, depthwise=depthwise)            
+        )
+        self.layer_2 = nn.Sequential(
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            ConvModule(hidden_dim, hidden_dim * 2,
+                       kernel_size=3, padding=1, stride=1,
+                       act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ConvModule(hidden_dim * 2, hidden_dim * 2,
+                       kernel_size=3, padding=1, stride=1,
+                       act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        self.layer_3 = nn.Sequential(
+            nn.MaxPool2d(kernel_size=2, stride=2),
+            ConvModule(hidden_dim * 2, hidden_dim * 4,
+                       kernel_size=3, padding=1, stride=1,
+                       act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ConvModule(hidden_dim * 4, hidden_dim * 4,
+                       kernel_size=3, padding=1, stride=1,
+                       act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        self.layer_4 = nn.Sequential(
+            ConvModule(hidden_dim * 4, hidden_dim * 4,
+                       kernel_size=3, padding=1, stride=1,
+                       act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ConvModule(hidden_dim * 4, hidden_dim * 4,
+                       kernel_size=3, padding=1, stride=1,
+                       act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+        if use_adavgpool:
+            self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
+            self.fc      = nn.Linear(hidden_dim * 4, num_classes)
+        else:
+            self.avgpool = None
+            fc_in_dim    = (img_size // 8) ** 2 * (hidden_dim * 4)  # N = Co x Ho x Wo
+            self.fc      = nn.Linear(fc_in_dim , num_classes)
+
+    def forward(self, x):
+        """
+        Input:
+            x : (torch.Tensor) -> [B, C, H, W]
+        Output:
+            x : (torch.Tensor) -> [B, Nc], Nc is the number of the object categories.
+        """
+        # [B, C_in, H, W]   -> [B, C1, H/2, W/2]
+        x = self.layer_1(x)
+        # [B, C1, H/2, W/2] -> [B, C2, H/4, W/4]
+        x = self.layer_2(x)
+        # [B, C2, H/4, W/4] -> [B, C3, H/8, W/8]
+        x = self.layer_3(x)
+        # [B, C3, H/8, W/8] -> [B, C3, H/8, W/8]
+        x = self.layer_4(x)
+
+        if self.use_adavgpool:
+            x = self.avgpool(x)
+
+        # reshape [B, Co, Ho, Wo] to [B, N], N = Co x Ho x Wo
+        x = x.flatten(1)
+        x = self.fc(x)
+
+        return x
+
+
+if __name__ == "__main__":
+    bs, img_dim, img_size = 8, 3, 28
+    hidden_dim  = 16
+    num_classes = 10
+    
+    # Make an input data randomly
+    x = torch.randn(bs, img_dim, img_size, img_size)
+
+    # Build a ConvNet model
+    model = ConvNet(img_size      = img_size,
+                    in_dim        = img_dim,
+                    hidden_dim    = hidden_dim,
+                    num_classes   = num_classes,
+                    act_type      = 'relu',
+                    norm_type     = 'bn',
+                    depthwise     = False,
+                    use_adavgpool = False)
+
+    # Inference
+    output = model(x)
+    print(output.shape)

+ 86 - 0
image_classification/models/convnet/modules.py

@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+
+
+def get_activation(act_type=None):
+    if   act_type == 'sigmoid':
+        return nn.Sigmoid()
+    elif act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if   norm_type == 'bn':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'ln':
+        return LayerNorm2d(dim)
+    elif norm_type == 'gn':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        
+        return x
+    
+
+# Basic convolutional module
+class ConvModule(nn.Module):
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 kernel_size :int  = 1,
+                 padding     :int  = 0,
+                 stride      :int  = 1,
+                 act_type    :str  = "relu",
+                 norm_type   :str  = "bn",
+                 depthwise   :bool = False) -> None:
+        super().__init__()
+        use_bias = False if norm_type is not None else True
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = nn.Conv2d(in_channels=in_dim, out_channels=out_dim,
+                                kernel_size=kernel_size, padding=padding, stride=stride,
+                                bias=use_bias)
+            self.norm  = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = nn.Conv2d(in_channels=in_dim, out_channels=in_dim,
+                                   kernel_size=kernel_size, padding=padding, stride=stride, groups=in_dim,
+                                   bias=use_bias)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = nn.Conv2d(in_channels=in_dim, out_channels=out_dim,
+                                   kernel_size=1, padding=0, stride=1,
+                                   bias=use_bias)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act   = get_activation(act_type)
+
+    def forward(self, x):
+        if self.depthwise:
+            x = self.norm1(self.conv1(x))
+            x = self.act(self.norm2(self.conv2(x)))
+        else:
+            x = self.act(self.norm(self.conv(x)))
+
+        return x
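
As a rough illustration of the `depthwise=True` branch above: a depthwise-separable convolution replaces one dense k x k convolution with a per-channel k x k convolution followed by a 1 x 1 pointwise convolution, shrinking the convolutional weight count from `in_dim * out_dim * k * k` to `in_dim * k * k + in_dim * out_dim`. A small sketch (assuming it is run next to this file, the same way `convnet.py` falls back to a plain `modules` import):

```Python
import torch
from modules import ConvModule

def conv_weight_count(m):
    # Count only the convolution weights; the BatchNorm affine parameters are excluded.
    return sum(p.numel() for name, p in m.named_parameters() if 'conv' in name)

dense = ConvModule(64, 128, kernel_size=3, padding=1, stride=1, depthwise=False)
dwise = ConvModule(64, 128, kernel_size=3, padding=1, stride=1, depthwise=True)

x = torch.randn(1, 64, 32, 32)
print(dense(x).shape, dwise(x).shape)                      # both torch.Size([1, 128, 32, 32])
print(conv_weight_count(dense), conv_weight_count(dwise))  # 73728 vs 8768
```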

+ 14 - 0
image_classification/models/mlp/build.py

@@ -0,0 +1,14 @@
+from .mlp import MLP
+
+def build_mlp(args):
+    if args.model == "mlp":
+        model = MLP(in_dim     = args.mlp_in_dim,
+                    inter_dim  = 1024,
+                    out_dim    = args.num_classes,
+                    act_type   = "relu",
+                    norm_type  = "bn")
+        
+    else:
+        raise NotImplementedError("Unknown model: {}".format(args.model))
+    
+    return model

+ 56 - 0
image_classification/models/mlp/mlp.py

@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .modules import SLP
+except:
+    from  modules import SLP
+
+
+# Multi Layer Perceptron
+class MLP(nn.Module):
+    def __init__(self,
+                 in_dim     :int,
+                 inter_dim  :int,
+                 out_dim    :int,
+                 act_type   :str = "sigmoid",
+                 norm_type  :str = "bn") -> None:
+        super().__init__()
+        self.stem   = SLP(in_dim, inter_dim, act_type, norm_type)
+        self.layers = nn.Sequential(
+            SLP(inter_dim, inter_dim, act_type, norm_type),
+            SLP(inter_dim, inter_dim, act_type, norm_type),
+            SLP(inter_dim, inter_dim, act_type, norm_type),
+            SLP(inter_dim, inter_dim, act_type, norm_type),            
+            )
+        self.fc     = nn.Linear(inter_dim, out_dim)
+
+    def forward(self, x):
+        """
+        Input:
+            x : (torch.Tensor) -> [B, C, H, W] or [B, C]
+        """
+        if len(x.shape) > 2:
+            x = x.flatten(1)
+
+        x = self.stem(x)
+        x = self.layers(x)
+        x = self.fc(x)
+
+        return x
+
+
+if __name__ == "__main__":
+    bs, c = 8, 256
+    hidden_dim  = 512
+    num_classes = 10
+    
+    # Make an input data randomly
+    x = torch.randn(bs, c)
+
+    # Build a MLP model
+    model = MLP(in_dim=c, inter_dim=hidden_dim, out_dim=num_classes, act_type='sigmoid', norm_type='bn')
+
+    # Inference
+    output = model(x)
+    print(output.shape)

+ 48 - 0
image_classification/models/mlp/modules.py

@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+
+
+def get_activation(act_type=None):
+    if   act_type == 'sigmoid':
+        return nn.Sigmoid()
+    elif act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if   norm_type == 'bn':
+        return nn.BatchNorm1d(dim)
+    elif norm_type == 'ln':
+        return nn.LayerNorm(dim)
+    elif norm_type == 'gn':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+
+# Single Layer Perceptron
+class SLP(nn.Module):
+    def __init__(self,
+                 in_dim    :int,
+                 out_dim   :int,
+                 act_type  :str = "sigmoid",
+                 norm_type :str = "bn") -> None:
+        super().__init__()
+        use_bias = False if norm_type is not None else True
+        self.layer = nn.Linear(in_features=in_dim, out_features=out_dim, bias=use_bias)
+        self.norm  = get_norm(norm_type, out_dim)
+        self.act   = get_activation(act_type)
+
+    def forward(self, x):
+        return self.act(self.norm(self.layer(x)))

+ 21 - 0
image_classification/models/resnet/build.py

@@ -0,0 +1,21 @@
+from .resnet import ResNet
+from .modules import PlainResBlock, BottleneckResBlock
+
+
+def build_resnet(args):
+    if args.model == 'resnet18':
+        model = ResNet(in_dim=args.img_dim,
+                       block=PlainResBlock,
+                       expansion=1.0,
+                       num_blocks=[2, 2, 2, 2],
+                       num_classes=args.num_classes,
+                       )
+    elif args.model == 'resnet50':
+        model = ResNet(in_dim=args.img_dim,
+                       block=BottleneckResBlock,
+                       expansion=4.0,
+                       num_blocks=[3, 4, 6, 3],
+                       num_classes=args.num_classes,
+                       )
+    else:
+        raise NotImplementedError("Unknown resnet: {}".format(args.model))
+    
+    return model

+ 164 - 0
image_classification/models/resnet/modules.py

@@ -0,0 +1,164 @@
+import torch
+import torch.nn as nn
+
+
+def get_activation(act_type=None):
+    if   act_type == 'sigmoid':
+        return nn.Sigmoid()
+    elif act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if   norm_type == 'bn':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'ln':
+        return LayerNorm2d(dim)
+    elif norm_type == 'gn':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        
+        return x
+    
+class ConvModule(nn.Module):
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 kernel_size :int  = 1,
+                 padding     :int  = 0,
+                 stride      :int  = 1,
+                 act_type    :str  = "relu",
+                 norm_type   :str  = "bn",
+                 depthwise   :bool = False) -> None:
+        super().__init__()
+        use_bias = False if norm_type is not None else True
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = nn.Conv2d(in_channels=in_dim, out_channels=out_dim,
+                                kernel_size=kernel_size, padding=padding, stride=stride,
+                                bias=use_bias)
+            self.norm  = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = nn.Conv2d(in_channels=in_dim, out_channels=in_dim,
+                                   kernel_size=kernel_size, padding=padding, stride=stride, groups=in_dim,
+                                   bias=use_bias)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = nn.Conv2d(in_channels=in_dim, out_channels=out_dim,
+                                   kernel_size=1, padding=0, stride=1,
+                                   bias=use_bias)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act   = get_activation(act_type)
+
+    def forward(self, x):
+        if self.depthwise:
+            x = self.norm1(self.conv1(x))
+            x = self.act(self.norm2(self.conv2(x)))
+        else:
+            x = self.act(self.norm(self.conv(x)))
+
+        return x
+
+
+# -------------- ResNet's modules --------------
+class PlainResBlock(nn.Module):
+    def __init__(self, in_dim, inter_dim, out_dim, stride=1):
+        super().__init__()
+        # -------- Basic parameters --------
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.inter_dim = inter_dim
+        self.stride = stride
+        self.downsample = stride > 1 or in_dim != out_dim
+
+        # -------- Model parameters --------
+        self.conv_layer_1 = ConvModule(in_dim, inter_dim,
+                                       kernel_size=3, padding=1, stride=stride,
+                                       act_type='relu', norm_type='bn', depthwise=False)
+        self.conv_layer_2 = ConvModule(inter_dim, out_dim,
+                                       kernel_size=3, padding=1, stride=1,
+                                       act_type=None, norm_type='bn', depthwise=False)
+        self.out_act = nn.ReLU(inplace=True)
+
+        if self.downsample:
+            self.res_layer = ConvModule(in_dim, out_dim,
+                                       kernel_size=1, padding=0, stride=stride,
+                                       act_type=None, norm_type='bn', depthwise=False)
+        else:
+            self.res_layer = nn.Identity()
+
+    def forward(self, x):
+        out = self.conv_layer_1(x)
+        out = self.conv_layer_2(out)
+
+        x = self.res_layer(x)
+
+        out = x + out
+        out = self.out_act(out)
+
+        return out
+
+class BottleneckResBlock(nn.Module):
+    def __init__(self, in_dim, inter_dim, out_dim, stride=1):
+        super().__init__()
+        # -------- Basic parameters --------
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.stride = stride
+        self.downsample = stride > 1 or in_dim != out_dim
+
+        # -------- Model parameters --------
+        self.conv_layer_1 = ConvModule(in_dim, inter_dim,
+                                       kernel_size=1, padding=0, stride=1,
+                                       act_type='relu', norm_type='bn', depthwise=False)
+        self.conv_layer_2 = ConvModule(inter_dim, inter_dim,
+                                       kernel_size=3, padding=1, stride=stride,
+                                       act_type='relu', norm_type='bn', depthwise=False)
+        self.conv_layer_3 = ConvModule(inter_dim, out_dim,
+                                       kernel_size=1, padding=0, stride=1,
+                                       act_type=None, norm_type='bn', depthwise=False)
+        self.out_act = nn.ReLU(inplace=True)
+
+        if self.downsample:
+            self.res_layer = ConvModule(in_dim, out_dim,
+                                       kernel_size=1, padding=0, stride=stride,
+                                       act_type=None, norm_type='bn', depthwise=False)
+        else:
+            self.res_layer = nn.Identity()
+
+    def forward(self, x):
+        out = self.conv_layer_1(x)
+        out = self.conv_layer_2(out)
+        out = self.conv_layer_3(out)
+
+        x = self.res_layer(x)
+
+        out = x + out
+        out = self.out_act(out)
+
+        return out
+
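
A quick shape check of the two residual blocks above (a sketch, run next to this file): the 1 x 1 projection on the shortcut is used whenever `stride > 1` or the input and output widths differ; otherwise the shortcut is the identity.

```Python
import torch
from modules import PlainResBlock, BottleneckResBlock

x = torch.randn(1, 64, 56, 56)

# Identity shortcut: same width, stride 1.
plain = PlainResBlock(in_dim=64, inter_dim=64, out_dim=64, stride=1)
print(plain(x).shape)        # torch.Size([1, 64, 56, 56])

# Projected shortcut: channel expansion and spatial downsampling.
bottleneck = BottleneckResBlock(in_dim=64, inter_dim=64, out_dim=256, stride=2)
print(bottleneck(x).shape)   # torch.Size([1, 256, 28, 28])
```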

+ 110 - 0
image_classification/models/resnet/resnet.py

@@ -0,0 +1,110 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .modules import ConvModule, PlainResBlock, BottleneckResBlock
+except:
+    from  modules import ConvModule, PlainResBlock, BottleneckResBlock
+
+
+class ResNet(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 block,
+                 expansion = 1.0,
+                 num_blocks = [2, 2, 2, 2],
+                 num_classes = 1000,
+                 ) -> None:
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.expansion = expansion
+        self.num_blocks = num_blocks
+        self.feat_dims  = [64,                      # C2 level
+                           round(64 * expansion),   # C2 level
+                           round(128 * expansion),  # C3 level
+                           round(256 * expansion),  # C4 level
+                           round(512 * expansion),  # C5 level
+                           ]
+        # ----------- Model parameters -----------
+        ## Backbone
+        self.layer_1 = nn.Sequential(
+            ConvModule(in_dim, self.feat_dims[0],
+                       kernel_size=7, padding=3, stride=2,
+                       act_type='relu', norm_type='bn', depthwise=False),
+            nn.MaxPool2d(kernel_size=(3, 3), padding=(1, 1), stride=(2, 2))
+        )
+        self.layer_2 = self.make_layer(block, self.feat_dims[0], self.feat_dims[1], depth=num_blocks[0], downsample=False)
+        self.layer_3 = self.make_layer(block, self.feat_dims[1], self.feat_dims[2], depth=num_blocks[1], downsample=True)
+        self.layer_4 = self.make_layer(block, self.feat_dims[2], self.feat_dims[3], depth=num_blocks[2], downsample=True)
+        self.layer_5 = self.make_layer(block, self.feat_dims[3], self.feat_dims[4], depth=num_blocks[3], downsample=True)
+
+        ## Classifier
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc      = nn.Linear(self.feat_dims[4] , num_classes)
+        
+    def make_layer(self, block, in_dim, out_dim, depth=1, downsample=False):
+        stage_blocks = []
+        for i in range(depth):
+            if i == 0:
+                stride = 2 if downsample else 1
+                inter_dim = round(out_dim / self.expansion)
+                stage_blocks.append(block(in_dim, inter_dim, out_dim, stride))
+            else:
+                stride = 1
+                inter_dim = round(out_dim / self.expansion)
+                stage_blocks.append(block(out_dim, inter_dim, out_dim, stride))
+        
+        layers = nn.Sequential(*stage_blocks)
+
+        return layers
+    
+    def forward(self, x):
+        x = self.layer_1(x)
+        x = self.layer_2(x)
+        x = self.layer_3(x)
+        x = self.layer_4(x)
+        x = self.layer_5(x)
+
+        x = self.avgpool(x)
+        x = x.flatten(1)
+        x = self.fc(x)
+
+        return x
+
+
+def build_resnet(model_name='resnet18', img_dim=3):
+    if model_name == 'resnet18':
+        model = ResNet(in_dim=img_dim,
+                       block=PlainResBlock,
+                       expansion=1.0,
+                       num_blocks=[2, 2, 2, 2],
+                       )
+    elif model_name == 'resnet50':
+        model = ResNet(in_dim=img_dim,
+                       block=BottleneckResBlock,
+                       expansion=4.0,
+                       num_blocks=[3, 4, 6, 3],
+                       )
+    else:
+        raise NotImplementedError("Unknown resnet: {}".format(model_name))
+    
+    return model
+
+
+if __name__=='__main__':
+    import time
+
+    # Build a ResNet model
+    model = build_resnet(model_name='resnet18')
+
+    # Print the model structure
+    print(model)
+
+    # Randomly generate input data
+    x = torch.randn(1, 3, 224, 224)
+
+    # Model forward inference
+    t0 = time.time()
+    output = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)

+ 0 - 0
iclab/models/vit/build.py → image_classification/models/vit/build.py


+ 0 - 0
iclab/models/vit/modules.py → image_classification/models/vit/modules.py


+ 0 - 0
iclab/models/vit/vit.py → image_classification/models/vit/vit.py


+ 0 - 0
iclab/requirements.txt → image_classification/requirements.txt


+ 0 - 0
iclab/utils/__init__.py → image_classification/utils/__init__.py


+ 38 - 0
image_classification/utils/lr_scheduler.py

@@ -0,0 +1,38 @@
+import torch
+
+
+# Basic Warmup Scheduler
+class LinearWarmUpLrScheduler(object):
+    def __init__(self, base_lr=0.01, wp_iter=500, warmup_factor=0.00066667):
+        self.base_lr = base_lr
+        self.wp_iter = wp_iter
+        self.warmup_factor = warmup_factor
+
+    def set_lr(self, optimizer, cur_lr):
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = cur_lr
+
+    def __call__(self, iter, optimizer):
+        # warmup
+        assert iter < self.wp_iter
+        alpha = iter / self.wp_iter
+        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
+        tmp_lr = self.base_lr * warmup_factor
+        self.set_lr(optimizer, tmp_lr)
+
+
+def build_lr_scheduler(args, optimizer):
+    print("=================== LR Scheduler information ===================")
+    print("LR Scheduler: ", args.lr_scheduler)
+
+    if args.lr_scheduler == "step":
+        lr_step = [args.max_epoch // 2, args.max_epoch // 4 * 3]
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_step, gamma=0.1)
+        print("lr step: ", lr_step)
+    elif args.lr_scheduler == "cosine":
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch - args.wp_epoch - 1, eta_min=args.min_lr)
+    else:
+        raise NotImplementedError("Unknown lr scheduler: {}".format(args.lr_scheduler))
+    
+    return scheduler
+        
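
A minimal sketch of how the warmup scheduler and the epoch-level scheduler are meant to be combined (the loop below is a toy skeleton with made-up numbers; in the tutorial, `main.py` builds both and hands `lr_scheduler_warmup` to `train_one_epoch`):

```Python
import torch
from utils.lr_scheduler import LinearWarmUpLrScheduler

model     = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

base_lr, wp_epoch, max_epoch, iters_per_epoch = 0.1, 1, 5, 100
warmup    = LinearWarmUpLrScheduler(base_lr, wp_iter=wp_epoch * iters_per_epoch)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epoch - wp_epoch - 1, eta_min=1e-6)

for epoch in range(max_epoch):
    for iter_i in range(iters_per_epoch):
        global_iter = epoch * iters_per_epoch + iter_i
        if global_iter < wp_epoch * iters_per_epoch:
            warmup(global_iter, optimizer)   # linearly ramp the lr up to base_lr
        # ... forward / loss / backward / optimizer.step() would go here ...
    if (epoch + 1) > wp_epoch:
        scheduler.step()                     # cosine decay after the warmup epochs
    print(epoch, optimizer.param_groups[0]['lr'])
```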

+ 198 - 0
image_classification/utils/misc.py

@@ -0,0 +1,198 @@
+import time
+import torch
+import numpy as np
+import random
+import datetime
+from collections import defaultdict, deque
+from pathlib import Path
+
+
+
+# ---------------------- Common functions ----------------------
+def setup_seed(seed=42):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
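
A quick sanity check of `accuracy` above (a hypothetical two-sample batch with three classes; assumes the function is in scope):

```Python
import torch

logits = torch.tensor([[0.1, 0.7, 0.2],
                       [0.8, 0.1, 0.1]])
labels = torch.tensor([1, 2])      # the second sample is misclassified

top1, = accuracy(logits, labels, topk=(1,))
print(top1)   # tensor([50.])
```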
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if v is None:
+                continue
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        log_msg = [
+            header,
+            '[{0' + space_fmt + '}/{1}]',
+            'eta: {eta}',
+            '{meters}',
+            'time: {time}',
+            'data: {data}'
+        ]
+        if torch.cuda.is_available():
+            log_msg.append('max mem: {memory:.0f}')
+        log_msg = self.delimiter.join(log_msg)
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(
+            header, total_time_str, total_time / len(iterable)))
+
+
+# ---------------------- Model functions ----------------------
+def load_model(args, model, optimizer, lr_scheduler):
+    if args.resume and args.resume.lower() != 'none':
+        print("=================== Load checkpoint ===================")
+        if args.resume.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.resume, map_location='cpu')
+        model.load_state_dict(checkpoint['model'])
+        print("Resume checkpoint %s" % args.resume)
+        
+        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
+            print('- Load optimizer from the checkpoint. ')
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            args.start_epoch = checkpoint['epoch'] + 1
+
+        if 'lr_scheduler' in checkpoint:
+            print('- Load lr scheduler from the checkpoint. ')
+            lr_scheduler.load_state_dict(checkpoint.pop("lr_scheduler"))
+
+def save_model(args, epoch, model, optimizer, lr_scheduler, acc1=None):
+    output_dir = Path(args.output_dir)
+    epoch_name = str(epoch)
+    if acc1 is not None:
+        checkpoint_paths = [output_dir / ('checkpoint-{}-Acc1-{:.2f}.pth'.format(epoch_name, acc1))]
+    else:
+        checkpoint_paths = [output_dir / ('checkpoint-{}.pth'.format(epoch_name))]
+    for checkpoint_path in checkpoint_paths:
+        to_save = {
+            'model': model.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'lr_scheduler': lr_scheduler.state_dict(),
+            'epoch': epoch,
+            'args': args,
+        }
+        torch.save(to_save, checkpoint_path)

+ 34 - 0
image_classification/utils/optimzer.py

@@ -0,0 +1,34 @@
+import torch
+
+
+def build_optimizer(args, model):
+    print("=================== Optimizer information ===================")
+    print("Optimizer: ", args.optimizer)
+    
+    ## learning rate
+    if args.optimizer == "adamw":
+        batch_base = 256 if "vit" in args.model else 1024
+        args.base_lr = args.base_lr / batch_base * args.batch_size
+        optimizer = torch.optim.AdamW(model.parameters(),
+                                      lr=args.base_lr,
+                                      weight_decay=args.weight_decay)
+        print('- base lr: ', args.base_lr)
+        print('- min  lr: ', args.min_lr)
+        print('- weight_decay: ', args.weight_decay)
+    elif args.optimizer == "sgd":
+        batch_base = 128
+        args.base_lr = args.base_lr / batch_base * args.batch_size
+        optimizer = torch.optim.SGD(model.parameters(),
+                                    lr=args.base_lr,
+                                    momentum=0.9,
+                                    weight_decay=args.weight_decay)
+        print('- base lr: ', args.base_lr)
+        print('- min  lr: ', args.min_lr)
+        print('- momentum: ', 0.9)
+        print('- weight decay: ', args.weight_decay)
+    else:
+        raise NotImplementedError("Unknown optimizer: {}".format(args.optimizer))
+
+    print('- min  lr: ', args.min_lr)
+
+    return optimizer
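
The two branches above implement a linear learning-rate scaling rule: `--base_lr` is treated as the learning rate for a reference batch size (`batch_base`) and rescaled by the actual batch size. For example, with AdamW and a non-ViT model (`batch_base = 1024`), the defaults `--base_lr 1e-3 --batch_size 256` give an effective base lr of 1e-3 / 1024 * 256 = 2.5e-4.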

+ 12 - 0
masked_image_modeling/.gitignore

@@ -0,0 +1,12 @@
+*.pt
+*.pth
+*.pkl
+*.onnx
+*.pyc
+*.zip
+weights
+__pycache__
+data/cifar/
+data/cifar_data/
+data/mnist_data/
+vis_results/

+ 73 - 0
masked_image_modeling/README.md

@@ -0,0 +1,73 @@
+# Masked AutoEncoder
+
+## 1. Pretrain
+We provide the bash script `main_pretrain.sh` for pretraining. You can adjust the hyperparameters in the script to suit your needs.
+
+```Shell
+cd Vision-Pretraining-Tutorial/masked_image_modeling/
+python main_pretrain.py --cuda \
+                        --dataset cifar10 \
+                        --model vit_t \
+                        --mask_ratio 0.75 \
+                        --batch_size 128 \
+                        --optimizer adamw \
+                        --weight_decay 0.05 \
+                        --lr_scheduler cosine \
+                        --base_lr 0.00015 \
+                        --min_lr 0.0 \
+                        --max_epoch 400 \
+                        --eval_epoch 20
+```
+
+## 2. Finetune
+We provide the bash script `main_finetune.sh` for finetuning. You can adjust the hyperparameters in the script to suit your needs.
+
+```Shell
+cd Vision-Pretraining-Tutorial/masked_image_modeling/
+python main_finetune.py --cuda \
+                        --dataset cifar10 \
+                        --model vit_t \
+                        --batch_size 256 \
+                        --optimizer adamw \
+                        --weight_decay 0.05 \
+                        --base_lr 0.0005 \
+                        --min_lr 0.000001 \
+                        --max_epoch 100 \
+                        --wp_epoch 5 \
+                        --eval_epoch 5 \
+                        --pretrained path/to/vit_t.pth
+```
+## 3. Evaluate 
+- Evaluate the `top-1 & top-5` accuracy of `ViT-Tiny` on the CIFAR10 dataset:
+```Shell
+python main_finetune.py --cuda \
+                        --dataset cifar10 \
+                        -m vit_t \
+                        --batch_size 256 \
+                        --eval \
+                        --resume path/to/vit_t_cifar10.pth
+```
+
+
+## 4. Visualize Image Reconstruction
+- Visualize the reconstructions of `ViT-Tiny` pretrained with the MAE framework on the CIFAR10 dataset:
+```Shell
+python main_pretrain.py --cuda \
+                        --dataset cifar10 \
+                        -m vit_t \
+                        --resume path/to/mae_vit_t_cifar10.pth \
+                        --eval \
+                        --batch_size 1
+```
+
+
+## 5. Experiments
+- On CIFAR10
+
+| Method |  Model  | Epoch | Top 1    | Weight |  MAE weight  |
+|  :---: |  :---:  | :---: | :---:    | :---:  |    :---:     |
+|  MAE   |  ViT-T  | 100   |   91.2   | [ckpt](https://github.com/yjh0410/MAE/releases/download/checkpoints/ViT-T_Cifar10.pth) | [ckpt](https://github.com/yjh0410/MAE/releases/download/checkpoints/MAE_ViT-T_Cifar10.pth) |
+
+
+## 6. Acknowledgment
+Thank you to **Kaiming He** for his inspiring work on [MAE](http://openaccess.thecvf.com/content/CVPR2022/papers/He_Masked_Autoencoders_Are_Scalable_Vision_Learners_CVPR_2022_paper.pdf). His research effectively elucidates the semantic distinctions between vision and language, offering valuable insights for subsequent vision-related studies. I would also like to express my gratitude for the official source code of [MAE](https://github.com/facebookresearch/mae). Additionally, I appreciate the efforts of [**IcarusWizard**](https://github.com/IcarusWizard) for reproducing the [MAE](https://github.com/IcarusWizard/MAE) implementation.

+ 38 - 0
masked_image_modeling/data/__init__.py

@@ -0,0 +1,38 @@
+import torch
+
+from .cifar import CifarDataset
+from .custom import CustomDataset
+
+
+def build_dataset(args, is_train=False):
+    # ----------------- CIFAR dataset -----------------
+    if args.dataset == 'cifar10':
+        args.num_classes = 10
+        args.img_dim = 3
+        args.img_size = 32
+        args.patch_size = 4
+        return CifarDataset(is_train)
+        
+    # ----------------- Custom dataset -----------------
+    elif args.dataset == 'custom':
+        assert args.num_classes is not None and isinstance(args.num_classes, int)
+        args.img_size = 224
+        args.patch_size = 16
+        return CustomDataset(args, is_train)
+    
+    else:
+        print("Unknown dataset: {}".format(args.dataset))
+    
+
+def build_dataloader(args, dataset, is_train=False):
+    if is_train:
+        sampler = torch.utils.data.RandomSampler(dataset)
+        batch_sampler_train = torch.utils.data.BatchSampler(
+            sampler, args.batch_size, drop_last=True if is_train else False)
+        dataloader = torch.utils.data.DataLoader(
+            dataset, batch_sampler=batch_sampler_train, num_workers=args.num_workers, pin_memory=True)
+    else:
+        dataloader = torch.utils.data.DataLoader(
+            dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
+
+    return dataloader

+ 65 - 0
masked_image_modeling/data/cifar.py

@@ -0,0 +1,65 @@
+import os
+import numpy as np
+import torch.utils.data as data
+import torchvision.transforms as T
+from torchvision.datasets import CIFAR10
+
+
+class CifarDataset(data.Dataset):
+    def __init__(self, is_train=False):
+        super().__init__()
+        # ----------------- basic parameters -----------------
+        self.pixel_mean = [0.5, 0.5, 0.5]
+        self.pixel_std =  [0.5, 0.5, 0.5]
+        self.is_train  = is_train
+        self.image_set = 'train' if is_train else 'val'
+        # ----------------- dataset & transforms -----------------
+        self.transform = self.build_transform()
+        path = os.path.dirname(os.path.abspath(__file__))
+        if is_train:
+            self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=True, download=True, transform=self.transform)
+        else:
+            self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=False, download=True, transform=self.transform)
+
+    def __len__(self):
+        return len(self.dataset)
+    
+    def __getitem__(self, index):
+        image, target = self.dataset[index]
+            
+        return image, target
+    
+    def pull_image(self, index):
+        # load data
+        image, target = self.dataset[index]
+
+        # denormalize image
+        image = image.permute(1, 2, 0).numpy()
+        image = (image * self.pixel_std + self.pixel_mean) * 255.
+        image = image.astype(np.uint8)
+        image = image.copy()
+
+        return image, target
+
+    def build_transform(self):
+        # train and val currently share the same normalize-only transform
+        return T.Compose([T.ToTensor(), T.Normalize(0.5, 0.5)])
+
+if __name__ == "__main__":
+    import cv2
+    
+    # dataset
+    dataset = CifarDataset(is_train=True)  
+    print('Dataset size: ', len(dataset))
+
+    for i in range(len(dataset)):
+        image, target = dataset.pull_image(i)
+        # to BGR
+        image = image[..., (2, 1, 0)]
+
+        cv2.imshow('image', image)
+        cv2.waitKey(0)

+ 87 - 0
masked_image_modeling/data/custom.py

@@ -0,0 +1,87 @@
+import os
+import PIL
+import numpy as np
+import torch.utils.data as data
+import torchvision.transforms as T
+from torchvision.datasets import ImageFolder
+
+
+class CustomDataset(data.Dataset):
+    def __init__(self, args, is_train=False):
+        super().__init__()
+        # ----------------- basic parameters -----------------
+        self.args = args
+        self.is_train   = is_train
+        self.pixel_mean = [0.485, 0.456, 0.406]
+        self.pixel_std  = [0.229, 0.224, 0.225]
+        print("Pixel mean: {}".format(self.pixel_mean))
+        print("Pixel std:  {}".format(self.pixel_std))
+        self.image_set = 'train' if is_train else 'val'
+        self.data_path = os.path.join(args.root, self.image_set)
+        # ----------------- dataset & transforms -----------------
+        self.transform = self.build_transform()
+        self.dataset = ImageFolder(root=self.data_path, transform=self.transform)
+
+    def __len__(self):
+        return len(self.dataset)
+    
+    def __getitem__(self, index):
+        image, target = self.dataset[index]
+
+        return image, target
+    
+    def pull_image(self, index):
+        # load data
+        image, target = self.dataset[index]
+
+        # denormalize image
+        image = image.permute(1, 2, 0).numpy()
+        image = (image * self.pixel_std + self.pixel_mean) * 255.
+        image = image.astype(np.uint8)
+        image = image.copy()
+
+        return image, target
+
+    def build_transform(self):
+        if self.is_train:
+            transforms = T.Compose([
+                            T.RandomResizedCrop(224),
+                            T.RandomHorizontalFlip(0.5),
+                            T.ToTensor(),
+                            T.Normalize(self.pixel_mean,
+                                        self.pixel_std)])
+        else:
+            transforms = T.Compose([
+                T.Resize(224, interpolation=PIL.Image.BICUBIC),
+                T.CenterCrop(224),
+                T.ToTensor(),
+                T.Normalize(self.pixel_mean, self.pixel_std),
+            ])
+
+        return transforms
+
+
+if __name__ == "__main__":
+    import cv2
+    import argparse
+    
+    parser = argparse.ArgumentParser(description='Custom-Dataset')
+
+    # opt
+    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/classification/dataset/Animals/',
+                        help='data root')
+    parser.add_argument('--img_size', default=224, type=int,
+                        help='input image size.')
+    args = parser.parse_args()
+  
+    # Dataset
+    dataset = CustomDataset(args, is_train=True)  
+    print('Dataset size: ', len(dataset))
+
+    for i in range(len(dataset)):
+        image, target = dataset.pull_image(i)
+        # to BGR
+        image = image[..., (2, 1, 0)]
+
+        cv2.imshow('image', image)
+        cv2.waitKey(0)

+ 107 - 0
masked_image_modeling/engine_finetune.py

@@ -0,0 +1,107 @@
+import sys
+import math
+import torch
+
+from utils.misc import MetricLogger, SmoothedValue, accuracy
+
+
+def train_one_epoch(args,
+                    device,
+                    model,
+                    data_loader,
+                    optimizer,
+                    epoch,
+                    lr_scheduler_warmup,
+                    criterion,
+                    ):
+    model.train(True)
+    metric_logger = MetricLogger(delimiter="  ")
+    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    header = 'Epoch: [{}]'.format(epoch)
+    print_freq = 20
+    epoch_size = len(data_loader)
+
+    optimizer.zero_grad()
+
+    # train one epoch
+    for iter_i, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+        ni = iter_i + epoch * epoch_size
+        nw = args.wp_epoch * epoch_size
+
+        # Warmup
+        if nw > 0 and ni < nw:
+            lr_scheduler_warmup(ni, optimizer)
+        elif ni == nw:
+            print("Warmup stage is over.")
+            lr_scheduler_warmup.set_lr(optimizer, args.base_lr)
+
+        # To device
+        images = images.to(device, non_blocking=True)
+        targets = targets.to(device, non_blocking=True)
+
+        # Inference
+        output = model(images)
+
+        # Compute loss
+        loss = criterion(output, targets)
+
+        # Check loss
+        loss_value = loss.item()
+        if not math.isfinite(loss_value):
+            print("Loss is {}, stopping training".format(loss_value))
+            sys.exit(1)
+
+        # Backward
+        loss.backward()
+
+        # Optimize
+        optimizer.step()
+        optimizer.zero_grad()
+
+        # Logs
+        lr = optimizer.param_groups[0]["lr"]
+        metric_logger.update(loss=loss_value)
+        metric_logger.update(lr=lr)
+
+    # gather the stats from all processes
+    print("Averaged stats: {}".format(metric_logger))
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(data_loader, model, device):
+    criterion = torch.nn.CrossEntropyLoss()
+
+    metric_logger = MetricLogger(delimiter="  ")
+    header = 'Test:'
+
+    # switch to evaluation mode
+    model.eval()
+
+    for batch in metric_logger.log_every(data_loader, 10, header):
+        images = batch[0]
+        target = batch[1]
+        images = images.to(device, non_blocking=True)
+        target = target.to(device, non_blocking=True)
+
+        # Inference
+        output = model(images)
+
+        # Compute loss
+        loss = criterion(output, target)
+
+        # Compute accuracy
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        batch_size = images.shape[0]
+        metric_logger.update(loss=loss.item())
+        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+
+    # gather the stats from all processes
+    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss),
+          )
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

+ 64 - 0
masked_image_modeling/engine_pretrain.py

@@ -0,0 +1,64 @@
+import sys
+import math
+
+from utils.misc import MetricLogger, SmoothedValue
+
+
+def train_one_epoch(args,
+                    device,
+                    model,
+                    data_loader,
+                    optimizer,
+                    epoch,
+                    lr_scheduler_warmup,
+                    ):
+    model.train(True)
+    metric_logger = MetricLogger(delimiter="  ")
+    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    header = 'Epoch: [{}]'.format(epoch)
+    print_freq = 20
+    epoch_size = len(data_loader)
+
+    # Train one epoch
+    for iter_i, (images, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+        ni = iter_i + epoch * epoch_size
+        nw = args.wp_epoch * epoch_size
+        
+        # Warmup
+        if nw > 0 and ni < nw:
+            lr_scheduler_warmup(ni, optimizer)
+        elif ni == nw:
+            print("Warmup stage is over.")
+            lr_scheduler_warmup.set_lr(optimizer, args.base_lr)
+
+        # To device
+        images = images.to(device, non_blocking=True)
+
+        # Inference
+        output = model(images)
+
+        # Compute loss
+        loss = output["loss"]
+
+        # Check loss
+        loss_value = loss.item()
+        if not math.isfinite(loss_value):
+            print("Loss is {}, stopping training".format(loss_value))
+            sys.exit(1)
+
+        # Backward
+        loss.backward()
+
+        # Optimize
+        optimizer.step()
+        optimizer.zero_grad()
+
+        # Logs
+        lr = optimizer.param_groups[0]["lr"]
+        metric_logger.update(loss=loss_value)
+        metric_logger.update(lr=lr)
+
+    # gather the stats from all processes
+    print("Averaged stats: {}".format(metric_logger))
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
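
Note: unlike `engine_finetune.py`, this training loop takes no criterion. It assumes the model's forward pass returns a dict with a `"loss"` entry, which is what `ViTforMaskedAutoEncoder` in `models/vit/vit_mae.py` produces in training mode.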

+ 175 - 0
masked_image_modeling/main_finetune.py

@@ -0,0 +1,175 @@
+import os
+import time
+import argparse
+import datetime
+
+# ---------------- Torch components ----------------
+import torch
+import torch.backends.cudnn as cudnn
+
+# ---------------- Dataset components ----------------
+from data import build_dataset, build_dataloader
+
+# ---------------- Model components ----------------
+from models import build_model
+
+# ---------------- Utils components ----------------
+from utils.misc import setup_seed, load_model, save_model
+from utils.optimizer import build_optimizer
+from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
+
+# ---------------- Training engine ----------------
+from engine_finetune import train_one_epoch, evaluate
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    # Input
+    parser.add_argument('--img_dim', type=int, default=3,
+                        help='3 for RGB; 1 for Gray.')    
+    parser.add_argument('--patch_size', type=int, default=16,
+                        help='patch_size.')
+    # Basic
+    parser.add_argument('--seed', type=int, default=42,
+                        help='random seed.')
+    parser.add_argument('--cuda', action='store_true', default=False,
+                        help='use cuda')
+    parser.add_argument('--batch_size', type=int, default=256,
+                        help='batch size on all GPUs')
+    parser.add_argument('--num_workers', type=int, default=4,
+                        help='number of workers')
+    parser.add_argument('--path_to_save', type=str, default='weights/',
+                        help='path to save trained model.')
+    parser.add_argument('--eval', action='store_true', default=False,
+                        help='evaluate model.')
+    # Epoch
+    parser.add_argument('--wp_epoch', type=int, default=5, 
+                        help='warmup epoch')
+    parser.add_argument('--start_epoch', type=int, default=0, 
+                        help='start epoch')
+    parser.add_argument('--max_epoch', type=int, default=50, 
+                        help='max epoch')
+    parser.add_argument('--eval_epoch', type=int, default=5, 
+                        help='evaluation interval (in epochs)')
+    # Dataset
+    parser.add_argument('--dataset', type=str, default='cifar10',
+                        help='dataset name')
+    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
+                        help='path to dataset folder')
+    parser.add_argument('--num_classes', type=int, default=None, 
+                        help='number of classes.')
+    # Model
+    parser.add_argument('-m', '--model', type=str, default='vit_t',
+                        help='model name')
+    parser.add_argument('--pretrained', default=None, type=str,
+                        help='load pretrained weight.')
+    parser.add_argument('--resume', default=None, type=str,
+                        help='keep training')
+    parser.add_argument('--drop_path', type=float, default=0.1,
+                        help='drop_path')
+    # Optimizer
+    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
+                        help='sgd, adam')
+    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
+                        help='weight decay')
+    parser.add_argument('--base_lr', type=float, default=0.001,
+                        help='learning rate for training model')
+    parser.add_argument('--min_lr', type=float, default=0,
+                        help='the final lr')
+    # Lr scheduler
+    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='cosine',
+                        help='step, cosine')
+
+    return parser.parse_args()
+
+    
+def main():
+    args = parse_args()
+    # set random seed
+    setup_seed(args.seed)
+
+    # Path to save model
+    path_to_save = os.path.join(args.path_to_save, args.dataset, "finetune", args.model)
+    os.makedirs(path_to_save, exist_ok=True)
+    args.output_dir = path_to_save
+
+    # ------------------------- Build CUDA -------------------------
+    if args.cuda:
+        if torch.cuda.is_available():
+            cudnn.benchmark = True
+            device = torch.device("cuda")
+        else:
+            print('There is no available GPU.')
+            args.cuda = False
+            device = torch.device("cpu")
+    else:
+        device = torch.device("cpu")
+
+    # ------------------------- Build Dataset -------------------------
+    train_dataset = build_dataset(args, is_train=True)
+    val_dataset   = build_dataset(args, is_train=False)
+
+    # ------------------------- Build Dataloader -------------------------
+    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
+    val_dataloader   = build_dataloader(args, val_dataset,   is_train=False)
+
+    print('=================== Dataset Information ===================')
+    print('Train dataset size : ', len(train_dataset))
+    print('Val dataset size   : ', len(val_dataset))
+
+    # ------------------------- Build Model -------------------------
+    model = build_model(args, model_type='cls')
+    model.train().to(device)
+    print(model)
+
+    # ------------------------- Build Optimizer -------------------------
+    optimizer = build_optimizer(args, model)
+
+    # ------------------------- Build Lr Scheduler -------------------------
+    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
+    lr_scheduler = build_lr_scheduler(args, optimizer)
+
+
+    # ------------------------- Build Criterion -------------------------
+    criterion = torch.nn.CrossEntropyLoss()
+    load_model(args, model, optimizer, lr_scheduler)
+
+    # ------------------------- Eval before Train Pipeline -------------------------
+    if args.eval:
+        print('evaluating ...')
+        test_stats = evaluate(val_dataloader, model, device)
+        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5 : %.2f]' %
+                (test_stats['loss'], test_stats['acc1'], test_stats['acc5']), flush=True)
+        return
+
+    # ------------------------- Training Pipeline -------------------------
+    start_time = time.time()
+    max_accuracy = -1.0
+    print("=============== Start training for {} epochs ===============".format(args.max_epoch))
+    for epoch in range(args.start_epoch, args.max_epoch):
+        # Train one epoch
+        train_one_epoch(args, device, model, train_dataloader, optimizer,
+                        epoch, lr_scheduler_warmup, criterion)
+
+        # LR scheduler
+        if (epoch + 1) > args.wp_epoch:
+            lr_scheduler.step()
+
+        # Evaluate
+        if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
+            test_stats = evaluate(val_dataloader, model, device)
+            print(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%")
+            max_accuracy = max(max_accuracy, test_stats["acc1"])
+            print(f'Max accuracy: {max_accuracy:.2f}%')
+
+            # Save model
+            print('- saving the model after {} epochs ...'.format(epoch))
+            save_model(args, epoch, model, optimizer, lr_scheduler, acc1=max_accuracy)
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+
+if __name__ == "__main__":
+    main()

+ 211 - 0
masked_image_modeling/main_pretrain.py

@@ -0,0 +1,211 @@
+import os
+import cv2
+import time
+import datetime
+import argparse
+import numpy as np
+
+# ---------------- Torch components ----------------
+import torch
+import torch.backends.cudnn as cudnn
+
+# ---------------- Dataset & Model components ----------------
+from data import build_dataset, build_dataloader
+from models import build_model
+
+# ---------------- Utils components ----------------
+from utils.misc import setup_seed
+from utils.misc import load_model, save_model, unpatchify
+from utils.optimizer import build_optimizer
+from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
+
+# ---------------- Training engine ----------------
+from engine_pretrain import train_one_epoch
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    # Basic
+    parser.add_argument('--seed', type=int, default=42,
+                        help='random seed.')
+    parser.add_argument('--cuda', action='store_true', default=False,
+                        help='use cuda')
+    parser.add_argument('--batch_size', type=int, default=256,
+                        help='batch size on all GPUs')
+    parser.add_argument('--num_workers', type=int, default=4,
+                        help='number of workers')
+    parser.add_argument('--path_to_save', type=str, default='weights/',
+                        help='path to save trained model.')
+    parser.add_argument('--eval', action='store_true', default=False,
+                        help='evaluate model.')
+    # Epoch
+    parser.add_argument('--wp_epoch', type=int, default=20,
+                        help='warmup epochs for MAE pretraining')
+    parser.add_argument('--start_epoch', type=int, default=0,
+                        help='start epoch')
+    parser.add_argument('--eval_epoch', type=int, default=10,
+                        help='interval (in epochs) for saving checkpoints')
+    parser.add_argument('--max_epoch', type=int, default=200, 
+                        help='max epoch')
+    # Dataset
+    parser.add_argument('--dataset', type=str, default='cifar10',
+                        help='dataset name')
+    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
+                        help='path to dataset folder')
+    parser.add_argument('--num_classes', type=int, default=None, 
+                        help='number of classes.')
+    # Model
+    parser.add_argument('-m', '--model', type=str, default='vit_t',
+                        help='model name')
+    parser.add_argument('--resume', default=None, type=str,
+                        help='keep training')
+    parser.add_argument('--drop_path', type=float, default=0.,
+                        help='drop_path')
+    parser.add_argument('--mask_ratio', type=float, default=0.75,
+                        help='mask ratio.')    
+    # Optimizer
+    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
+                        help='sgd, adam')
+    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
+                        help='weight decay')
+    parser.add_argument('--base_lr', type=float, default=0.00015,
+                        help='learning rate for training model')
+    parser.add_argument('--min_lr', type=float, default=0,
+                        help='the final lr')
+    # Lr scheduler
+    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='cosine',
+                        help='step, cosine')
+
+    return parser.parse_args()
+
+    
+def main():
+    args = parse_args()
+    # set random seed
+    setup_seed(args.seed)
+
+    # Path to save model
+    path_to_save = os.path.join(args.path_to_save, args.dataset, "pretrained", args.model)
+    os.makedirs(path_to_save, exist_ok=True)
+    args.output_dir = path_to_save
+    
+    # ------------------------- Build CUDA -------------------------
+    if args.cuda:
+        if torch.cuda.is_available():
+            cudnn.benchmark = True
+            device = torch.device("cuda")
+        else:
+            print('There is no available GPU.')
+            args.cuda = False
+            device = torch.device("cpu")
+    else:
+        device = torch.device("cpu")
+
+    # ------------------------- Build Dataset -------------------------
+    train_dataset = build_dataset(args, is_train=True)
+
+    # ------------------------- Build Dataloader -------------------------
+    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
+    print('=================== Dataset Information ===================')
+    print('Train dataset size : {}'.format(len(train_dataset)))
+
+    # ------------------------- Build Model -------------------------
+    model = build_model(args, model_type='mae')
+    model.train().to(device)
+    print(model)
+
+    # ------------------------- Build Optimizer -------------------------
+    optimizer = build_optimizer(args, model)
+
+    # ------------------------- Build Lr Scheduler -------------------------
+    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
+    lr_scheduler = build_lr_scheduler(args, optimizer)
+
+    # ------------------------- Build checkpoint -------------------------
+    load_model(args, model, optimizer, lr_scheduler)
+
+    # ------------------------- Eval before Train Pipeline -------------------------
+    if args.eval:
+        print('visualizing ...')
+        visualize(args, device, model)
+        return
+
+    # ------------------------- Training Pipeline -------------------------
+    start_time = time.time()
+    print("=================== Start training for {} epochs ===================".format(args.max_epoch))
+    for epoch in range(args.start_epoch, args.max_epoch):
+        # Train one epoch
+        train_one_epoch(args, device, model, train_dataloader,
+                        optimizer, epoch, lr_scheduler_warmup)
+
+        # LR scheduler
+        if (epoch + 1) > args.wp_epoch:
+            lr_scheduler.step()
+
+        # Evaluate
+        if epoch % args.eval_epoch == 0 or epoch + 1 == args.max_epoch:
+            print('- saving the model after {} epochs ...'.format(epoch))
+            save_model(args, epoch, model, optimizer, lr_scheduler, mae_task=True)
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+def visualize(args, device, model):
+    # test dataset
+    val_dataset = build_dataset(args, is_train=False)
+    val_dataloader = build_dataloader(args, val_dataset, is_train=False)
+
+    # save path
+    save_path = "vis_results/{}/{}".format(args.dataset, args.model)
+    os.makedirs(save_path, exist_ok=True)
+
+    # switch to evaluate mode
+    model.eval()
+    patch_size = args.patch_size
+    pixel_mean = val_dataloader.dataset.pixel_mean
+    pixel_std  = val_dataloader.dataset.pixel_std
+
+    with torch.no_grad():
+        for i, (images, target) in enumerate(val_dataloader):
+            # To device
+            images = images.to(device, non_blocking=True)
+            target = target.to(device, non_blocking=True)
+
+            # Inference
+            output = model(images)
+
+            # Denormalize input image
+            org_img = images[0].permute(1, 2, 0).cpu().numpy()
+            org_img = (org_img * pixel_std + pixel_mean) * 255.
+            org_img = org_img.astype(np.uint8)
+
+            # Reshape the mask: [B, H*W] -> [B, H*W, p*p*3]
+            mask = output['mask'].unsqueeze(-1).repeat(1, 1, patch_size**2 *3)  # [B, H*W] -> [B, H*W, p*p*3]
+            # Convert the sequence-form mask back into a 2D image layout
+            mask = unpatchify(mask, patch_size)
+            mask = mask[0].permute(1, 2, 0).cpu().numpy()
+            # Black out the masked patch regions of the input image
+            masked_img = org_img * (1 - mask)  # 1 is removing, 0 is keeping
+            masked_img = masked_img.astype(np.uint8)
+
+            # Convert the sequence-form reconstruction back into a 2D image layout
+            pred_img = unpatchify(output['x_pred'], patch_size)
+            pred_img = pred_img[0].permute(1, 2, 0).cpu().numpy()
+            pred_img = (pred_img * pixel_std + pixel_mean) * 255.
+            # Stitch the kept patches of the original image together with the reconstructed patches
+            pred_img = org_img * (1 - mask) + pred_img * mask
+            pred_img = pred_img.astype(np.uint8)
+
+            # visualize
+            vis_image = np.concatenate([masked_img, org_img, pred_img], axis=1)
+            vis_image = vis_image[..., (2, 1, 0)]
+            cv2.imshow('masked | origin | reconstruct ', vis_image)
+            cv2.waitKey(0)
+
+            # save
+            cv2.imwrite('{}/{:06}.png'.format(save_path, i), vis_image)
+
+
+if __name__ == "__main__":
+    main()

+ 9 - 0
masked_image_modeling/models/__init__.py

@@ -0,0 +1,9 @@
+from .vit.build import build_vision_transformer
+
+
+def build_model(args, model_type='default'):
+    # ----------- Vision Transformer -----------
+    if "vit" in args.model:
+        return build_vision_transformer(args, model_type)
+    else:
+        raise NotImplementedError("Unknown model: {}".format(args.model))

+ 3 - 0
masked_image_modeling/models/vit/__init__.py

@@ -0,0 +1,3 @@
+from .vit import build_vit
+from .vit_mae import build_vit_mae
+from .vit_cls import ViTForImageClassification

+ 45 - 0
masked_image_modeling/models/vit/build.py

@@ -0,0 +1,45 @@
+import os
+import torch
+
+from .vit     import build_vit
+from .vit_mae import build_vit_mae
+from .vit_cls import ViTForImageClassification
+
+
+def build_vision_transformer(args, model_type='default'):
+    assert args.model in ['vit_t', 'vit_s', 'vit_b', 'vit_l', 'vit_h'], "Unknown vit model: {}".format(args.model)
+
+    # ----------- Masked Image Modeling task -----------
+    if model_type == 'mae':
+        model = build_vit_mae(args.model, args.img_size, args.patch_size, args.img_dim, args.mask_ratio)
+    
+    # ----------- Image Classification task -----------
+    elif model_type == 'cls':
+        image_encoder = build_vit(args.model, args.img_size, args.patch_size, args.img_dim)
+        model = ViTForImageClassification(image_encoder, num_classes=args.num_classes, qkv_bias=True)
+        load_mae_pretrained(model.encoder, args.pretrained)
+
+    # ----------- Vision Backbone -----------
+    elif model_type == 'default':
+        model = build_vit(args.model, args.img_size, args.patch_size, args.img_dim)
+        load_mae_pretrained(model, args.pretrained)
+        
+    else:
+        raise NotImplementedError("Unknown model type: {}".format(model_type))
+    
+    return model
+
+
+def load_mae_pretrained(model, ckpt=None):
+    if ckpt is not None:
+        # check path
+        if not os.path.exists(ckpt):
+            print("No pretrained model.")
+            return model
+        print('- Loading pretrained from: {}'.format(ckpt))
+        checkpoint = torch.load(ckpt, map_location='cpu')
+        # checkpoint state dict
+        encoder_state_dict = checkpoint.pop("encoder")
+
+        # load encoder weight into ViT's encoder
+        model.load_state_dict(encoder_state_dict)
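
`load_mae_pretrained` assumes the checkpoint is a dict whose `"encoder"` entry holds the encoder's `state_dict()`. A minimal sketch of writing and reading that layout (the file name and the CIFAR10-sized settings are illustrative; run from `masked_image_modeling/`):

```Python
import torch

from models.vit import build_vit

# Build a ViT-Tiny encoder and save it in the layout that load_mae_pretrained() expects.
encoder = build_vit("vit_t", img_size=32, patch_size=4, img_dim=3)
torch.save({"encoder": encoder.state_dict()}, "vit_t_encoder.pth")

# Loading mirrors the body of load_mae_pretrained().
checkpoint = torch.load("vit_t_encoder.pth", map_location="cpu")
encoder.load_state_dict(checkpoint.pop("encoder"))
```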

+ 186 - 0
masked_image_modeling/models/vit/modules.py

@@ -0,0 +1,186 @@
+# --------------------------------------------------------------------
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Type
+
+
+# ----------------------- Basic modules -----------------------
+class FeedForward(nn.Module):
+    def __init__(self,
+                 embedding_dim: int,
+                 mlp_dim: int,
+                 act: Type[nn.Module] = nn.GELU,
+                 dropout: float = 0.0,
+                 ) -> None:
+        super().__init__()
+        self.fc1   = nn.Linear(embedding_dim, mlp_dim)
+        self.drop1 = nn.Dropout(dropout)
+        self.fc2   = nn.Linear(mlp_dim, embedding_dim)
+        self.drop2 = nn.Dropout(dropout)
+        self.act   = act()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+        return x
+
+class PatchEmbed(nn.Module):
+    def __init__(self,
+                 in_chans    : int = 3,
+                 embed_dim   : int = 768,
+                 kernel_size : int = 16,
+                 padding     : int = 0,
+                 stride      : int = 16,
+                 ) -> None:
+        super().__init__()
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.proj(x)
+
+
+# ----------------------- Model modules -----------------------
+class ViTBlock(nn.Module):
+    def __init__(self,
+                 dim       :int,
+                 num_heads :int,
+                 mlp_ratio :float = 4.0,
+                 qkv_bias  :bool = True,
+                 act_layer :Type[nn.Module] = nn.GELU,
+                 dropout   :float = 0.
+                 ) -> None:
+        super().__init__()
+        # -------------- Model parameters --------------
+        self.norm1 = nn.LayerNorm(dim)
+        self.attn  = Attention(dim         = dim,
+                               qkv_bias    = qkv_bias,
+                               num_heads   = num_heads,
+                               dropout     = dropout
+                               )
+        self.norm2 = nn.LayerNorm(dim)
+        self.ffn   = FeedForward(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shortcut = x
+        # Attention (with prenorm)
+        x = self.norm1(x)
+        x = self.attn(x)
+        x = shortcut + x
+
+        # Feedforward (with prenorm)
+        x = x + self.ffn(self.norm2(x))
+
+        return x
+
+class Attention(nn.Module):
+    def __init__(self,
+                 dim       :int,
+                 qkv_bias  :bool  = False,
+                 num_heads :int   = 8,
+                 dropout   :float = 0.
+                 ):
+        super().__init__()
+        # --------------- Basic parameters ---------------
+        self.dim = dim
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+
+        # --------------- Network parameters ---------------
+        self.qkv_proj = nn.Linear(dim, dim*3, bias = qkv_bias)
+        self.attn_drop = nn.Dropout(dropout)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(dropout)
+
+    def forward(self, x):
+        bs, N, _ = x.shape
+        # ----------------- Input proj -----------------
+        qkv = self.qkv_proj(x)
+        q, k, v = torch.chunk(qkv, 3, dim=-1)
+
+        # ----------------- Multi-head Attn -----------------
+        ## [B, N, C] -> [B, N, H, C_h] -> [B, H, N, C_h]
+        q = q.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+        k = k.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+        v = v.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
+        ## [B, H, Nq, C_h] X [B, H, C_h, Nk] = [B, H, Nq, Nk]
+        attn = q * self.scale @ k.transpose(-1, -2)
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+        x = attn @ v # [B, H, Nq, C_h]
+
+        # ----------------- Output -----------------
+        x = x.permute(0, 2, 1, 3).contiguous().view(bs, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
+
+
+# ----------------------- Classifier -----------------------
+class AttentionPoolingClassifier(nn.Module):
+    def __init__(
+        self,
+        in_dim      : int,
+        out_dim     : int,
+        num_heads   : int = 12,
+        qkv_bias    : bool = False,
+        num_queries : int = 1,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = in_dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.k = nn.Linear(in_dim, in_dim, bias=qkv_bias)
+        self.v = nn.Linear(in_dim, in_dim, bias=qkv_bias)
+
+        self.cls_token = nn.Parameter(torch.randn(1, num_queries, in_dim) * 0.02)
+        self.linear = nn.Linear(in_dim, out_dim)
+        self.bn = nn.BatchNorm1d(in_dim, affine=False, eps=1e-6)
+
+        self.num_queries = num_queries
+
+    def forward(self, x: torch.Tensor):
+        B, N, C = x.shape
+
+        x = self.bn(x.transpose(-2, -1)).transpose(-2, -1)
+        cls_token = self.cls_token.expand(B, -1, -1)  # newly created class token
+
+        q = cls_token.reshape(
+            B, self.num_queries, self.num_heads, C // self.num_heads
+        ).permute(0, 2, 1, 3)
+        k = (
+            self.k(x)
+            .reshape(B, N, self.num_heads, C // self.num_heads)
+            .permute(0, 2, 1, 3)
+        )
+
+        q = q * self.scale
+        v = (
+            self.v(x)
+            .reshape(B, N, self.num_heads, C // self.num_heads)
+            .permute(0, 2, 1, 3)
+        )
+
+        attn = q @ k.transpose(-2, -1)
+        attn = attn.softmax(dim=-1)
+
+        x_cls = (attn @ v).transpose(1, 2).reshape(B, self.num_queries, C)
+        x_cls = x_cls.mean(dim=1)
+
+        out = self.linear(x_cls)
+
+        return out, x_cls
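
A quick shape check of the blocks above under the ViT-Tiny settings used elsewhere in this diff (192-dim tokens, 3 heads); the tensor sizes are illustrative:

```Python
import torch

from models.vit.modules import ViTBlock, AttentionPoolingClassifier

tokens = torch.randn(2, 64, 192)             # [batch, num_patches, dim]

block = ViTBlock(dim=192, num_heads=3)
print(block(tokens).shape)                   # torch.Size([2, 64, 192])

pooler = AttentionPoolingClassifier(in_dim=192, out_dim=10, num_heads=3)
logits, cls_feat = pooler(tokens)
print(logits.shape, cls_feat.shape)          # torch.Size([2, 10]) torch.Size([2, 192])
```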

+ 96 - 0
masked_image_modeling/models/vit/pos_embed.py

@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out) # (M, D/2)
+    emb_cos = np.cos(out) # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed']
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
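
These NumPy helpers follow the official MAE code; the ViT/MAE classes in this diff build their own sin-cos table through their `get_posembed` method, so this module is a standalone utility. A short example for an 8x8 patch grid (e.g. a 32x32 image with `patch_size=4`) and a 192-dim embedding:

```Python
import torch

from models.vit.pos_embed import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=192, grid_size=8, cls_token=False)
print(pos.shape)    # (64, 192), a NumPy array

# Copy it into a frozen positional-embedding parameter.
pos_embed = torch.nn.Parameter(torch.zeros(1, 64, 192), requires_grad=False)
pos_embed.data.copy_(torch.from_numpy(pos).float().unsqueeze(0))
```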

+ 180 - 0
masked_image_modeling/models/vit/vit.py

@@ -0,0 +1,180 @@
+# --------------------------------------------------------------------
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+
+try:
+    from .modules import PatchEmbed, ViTBlock
+except ImportError:
+    from modules import PatchEmbed, ViTBlock
+
+
+# ---------------------- Vision transformer ----------------------
+class ImageEncoderViT(nn.Module):
+    def __init__(self,
+                 img_size: int,
+                 patch_size: int,
+                 in_chans: int,
+                 patch_embed_dim: int,
+                 depth: int,
+                 num_heads: int,
+                 mlp_ratio: float,
+                 act_layer: nn.GELU,
+                 dropout: float = 0.0,
+                 ) -> None:
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
+        self.patch_embed_dim = patch_embed_dim
+        self.num_heads = num_heads
+        self.num_patches = (img_size // patch_size) ** 2
+        # ----------- Model parameters -----------
+        self.patch_embed = PatchEmbed(in_chans, patch_embed_dim, patch_size, stride=patch_size)
+        self.pos_embed   = nn.Parameter(torch.zeros(1, self.num_patches, patch_embed_dim))
+        self.norm_layer  = nn.LayerNorm(patch_embed_dim)
+        self.blocks      = nn.ModuleList([
+            ViTBlock(patch_embed_dim, num_heads, mlp_ratio, True, act_layer, dropout)
+            for _ in range(depth)])
+
+        self._init_weights()
+
+    def _init_weights(self):
+        # initialize (and freeze) pos_embed by sin-cos embedding
+        pos_embed = self.get_posembed(self.pos_embed.shape[-1], int(self.num_patches**.5))
+        self.pos_embed.data.copy_(pos_embed)
+
+        # initialize nn.Linear and nn.LayerNorm
+        for m in self.modules():           
+            if isinstance(m, nn.Linear):
+                # we use xavier_uniform following official JAX ViT:
+                torch.nn.init.xavier_uniform_(m.weight)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+    def get_posembed(self, embed_dim, grid_size, temperature=10000):
+        scale = 2 * torch.pi
+        grid_h, grid_w = grid_size, grid_size
+        num_pos_feats = embed_dim // 2
+        # get grid
+        y_embed, x_embed = torch.meshgrid([torch.arange(grid_h, dtype=torch.float32),
+                                           torch.arange(grid_w, dtype=torch.float32)])
+        # normalize grid coords
+        y_embed = y_embed / (grid_h + 1e-6) * scale
+        x_embed = x_embed / (grid_w + 1e-6) * scale
+    
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
+        dim_t = temperature ** (2 * dim_t_)
+
+        pos_x = torch.div(x_embed[..., None], dim_t)
+        pos_y = torch.div(y_embed[..., None], dim_t)
+        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
+        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
+
+        # [H, W, C] -> [N, C]
+        pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)
+
+        return pos_embed.unsqueeze(0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Patch embed
+        x = self.patch_embed(x)
+        x = x.flatten(2).permute(0, 2, 1).contiguous()
+
+        # Add pos embed
+        x = x + self.pos_embed
+
+        # Apply Transformer blocks
+        for block in self.blocks:
+            x = block(x)
+        x = self.norm_layer(x)
+
+        return x
+
+
+# ------------------------ Model Functions ------------------------
+def build_vit(model_name="vit_t", img_size=224, patch_size=16, img_dim=3):
+    if model_name == "vit_t":
+        return ImageEncoderViT(img_size=img_size,
+                               patch_size=patch_size,
+                               in_chans=img_dim,
+                               patch_embed_dim=192,
+                               depth=12,
+                               num_heads=3,
+                               mlp_ratio=4.0,
+                               act_layer=nn.GELU,
+                               dropout = 0.1)
+    if model_name == "vit_s":
+        return ImageEncoderViT(img_size=img_size,
+                               patch_size=patch_size,
+                               in_chans=img_dim,
+                               patch_embed_dim=384,
+                               depth=12,
+                               num_heads=6,
+                               mlp_ratio=4.0,
+                               act_layer=nn.GELU,
+                               dropout = 0.1)
+    if model_name == "vit_b":
+        return ImageEncoderViT(img_size=img_size,
+                               patch_size=patch_size,
+                               in_chans=img_dim,
+                               patch_embed_dim=768,
+                               depth=12,
+                               num_heads=12,
+                               mlp_ratio=4.0,
+                               act_layer=nn.GELU,
+                               dropout = 0.1)
+    if model_name == "vit_l":
+        return ImageEncoderViT(img_size=img_size,
+                               patch_size=patch_size,
+                               in_chans=img_dim,
+                               patch_embed_dim=1024,
+                               depth=24,
+                               num_heads=16,
+                               mlp_ratio=4.0,
+                               act_layer=nn.GELU,
+                               dropout = 0.1)
+    if model_name == "vit_h":
+        return ImageEncoderViT(img_size=img_size,
+                               patch_size=patch_size,
+                               in_chans=img_dim,
+                               patch_embed_dim=1280,
+                               depth=32,
+                               num_heads=16,
+                               mlp_ratio=4.0,
+                               act_layer=nn.GELU,
+                               dropout = 0.1)
+    
+
+if __name__ == '__main__':
+    import torch
+    from thop import profile
+
+    # Prepare an image as the input
+    bs, c, h, w = 2, 3, 224, 224
+    x = torch.randn(bs, c, h, w)
+    patch_size = 16
+
+    # Build model
+    model = build_vit(patch_size=patch_size)
+
+    # Inference
+    outputs = model(x)
+
+    # Compute FLOPs & Params
+    print('==============================')
+    model.eval()
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 28 - 0
masked_image_modeling/models/vit/vit_cls.py

@@ -0,0 +1,28 @@
+import torch.nn as nn
+
+from .modules import AttentionPoolingClassifier
+from .vit     import ImageEncoderViT
+
+
+class ViTForImageClassification(nn.Module):
+    def __init__(self,
+                 image_encoder :ImageEncoderViT,
+                 num_classes   :int   = 1000,
+                 qkv_bias      :bool  = True,
+                 ):
+        super().__init__()
+        # -------- Model parameters --------
+        self.encoder    = image_encoder
+        self.classifier = AttentionPoolingClassifier(
+            image_encoder.patch_embed_dim, num_classes, image_encoder.num_heads, qkv_bias, num_queries=1)
+
+    def forward(self, x):
+        """
+        Inputs:
+            x: (torch.Tensor) -> [B, C, H, W]. Input image.
+        """
+        x = self.encoder(x)
+        x, x_cls = self.classifier(x)
+
+        return x
+
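
A minimal sketch of a CIFAR10-sized classifier built from these pieces; `build_vision_transformer(args, model_type='cls')` wires the same components together (plus the optional MAE-pretrained loading):

```Python
import torch

from models.vit import build_vit, ViTForImageClassification

encoder = build_vit("vit_t", img_size=32, patch_size=4, img_dim=3)
model   = ViTForImageClassification(encoder, num_classes=10, qkv_bias=True)

logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)   # torch.Size([2, 10])
```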

+ 399 - 0
masked_image_modeling/models/vit/vit_mae.py

@@ -0,0 +1,399 @@
+import math
+import torch
+import torch.nn as nn
+
+try:
+    from .modules import ViTBlock, PatchEmbed
+except ImportError:
+    from modules import ViTBlock, PatchEmbed
+
+
+# ------------------------ Basic Modules ------------------------
+class MaeEncoder(nn.Module):
+    def __init__(self,
+                 img_size: int,
+                 patch_size: int,
+                 in_chans: int,
+                 patch_embed_dim: int,
+                 depth: int,
+                 num_heads: int,
+                 mlp_ratio: float,
+                 act_layer: nn.GELU,
+                 mask_ratio: float = 0.75,
+                 dropout: float = 0.0,
+                 ) -> None:
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
+        self.patch_embed_dim = patch_embed_dim
+        self.num_heads = num_heads
+        self.num_patches = (img_size // patch_size) ** 2
+        self.mask_ratio = mask_ratio
+        # ----------- Model parameters -----------
+        self.patch_embed = PatchEmbed(in_chans, patch_embed_dim, patch_size, 0, patch_size)
+        self.pos_embed   = nn.Parameter(torch.zeros(1, self.num_patches, patch_embed_dim), requires_grad=False)
+        self.norm_layer  = nn.LayerNorm(patch_embed_dim)
+        self.blocks      = nn.ModuleList([
+            ViTBlock(patch_embed_dim, num_heads, mlp_ratio, True, act_layer=act_layer, dropout=dropout)
+            for _ in range(depth)])
+        self._init_weights()
+
+    def _init_weights(self):
+        # initialize (and freeze) pos_embed by sin-cos embedding
+        pos_embed = self.get_posembed(self.pos_embed.shape[-1], int(self.num_patches**.5))
+        self.pos_embed.data.copy_(pos_embed)
+
+        # initialize nn.Linear and nn.LayerNorm
+        for m in self.modules():           
+            if isinstance(m, nn.Linear):
+                # we use xavier_uniform following official JAX ViT:
+                torch.nn.init.xavier_uniform_(m.weight)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+    def get_posembed(self, embed_dim, grid_size, temperature=10000):
+        scale = 2 * math.pi
+        grid_h, grid_w = grid_size, grid_size
+        num_pos_feats = embed_dim // 2
+        # get grid
+        y_embed, x_embed = torch.meshgrid([torch.arange(grid_h, dtype=torch.float32),
+                                           torch.arange(grid_w, dtype=torch.float32)])
+        # normalize grid coords
+        y_embed = y_embed / (grid_h + 1e-6) * scale
+        x_embed = x_embed / (grid_w + 1e-6) * scale
+    
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
+        dim_t = temperature ** (2 * dim_t_)
+
+        pos_x = torch.div(x_embed[..., None], dim_t)
+        pos_y = torch.div(y_embed[..., None], dim_t)
+        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
+        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
+
+        # [H, W, C] -> [N, C]
+        pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)
+
+        return pos_embed.unsqueeze(0)
+
+    def random_masking(self, x):
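+        # Per-sample random masking, as in MAE: draw uniform noise for every token,
+        # argsort it to get a random permutation, keep the first len_keep tokens,
+        # and keep ids_restore so the decoder can unshuffle back to image order.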
+        B, N, C = x.shape
+        len_keep = int(N * (1 - self.mask_ratio))
+
+        noise = torch.rand(B, N, device=x.device)  # noise in [0, 1]
+
+        # sort noise for each sample
+        ids_shuffle = torch.argsort(noise, dim=1)        # ascend: small is keep, large is remove
+        ids_restore = torch.argsort(ids_shuffle, dim=1)  # restore the original position of each patch
+
+        # keep the first subset
+        ids_keep = ids_shuffle[:, :len_keep]
+        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, C))
+
+        # generate the binary mask: 0 is keep, 1 is remove
+        mask = torch.ones([B, N], device=x.device)
+        mask[:, :len_keep] = 0
+
+        # unshuffle to get the binary mask
+        mask = torch.gather(mask, dim=1, index=ids_restore)
+
+        return x_masked, mask, ids_restore
+    
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # patch embed
+        x = self.patch_embed(x)
+        # [B, C, H, W] -> [B, C, N] -> [B, N, C], N = H x W
+        x = x.flatten(2).permute(0, 2, 1).contiguous()
+
+        # add pos embed
+        x = x + self.pos_embed
+
+        # masking: length -> length * mask_ratio
+        x, mask, ids_restore = self.random_masking(x)
+
+        # apply Transformer blocks
+        for block in self.blocks:
+            x = block(x)
+        x = self.norm_layer(x)
+        
+        return x, mask, ids_restore
+
+class MaeDecoder(nn.Module):
+    def __init__(self,
+                 img_dim       :int   = 3,
+                 img_size      :int   = 16,
+                 patch_size    :int   = 16,
+                 en_emb_dim    :int   = 784,
+                 de_emb_dim    :int   = 512,
+                 de_num_layers :int   = 12,
+                 de_num_heads  :int   = 12,
+                 qkv_bias      :bool  = True,
+                 mlp_ratio     :float = 4.0,
+                 dropout       :float = 0.1,
+                 mask_ratio    :float = 0.75,
+                 ):
+        super().__init__()
+        # -------- basic parameters --------
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = (img_size // patch_size) ** 2
+        self.en_emb_dim = en_emb_dim
+        self.de_emb_dim = de_emb_dim
+        self.de_num_layers = de_num_layers
+        self.de_num_heads = de_num_heads
+        self.mask_ratio = mask_ratio
+        # -------- network parameters --------
+        self.decoder_embed     = nn.Linear(en_emb_dim, de_emb_dim)
+        self.mask_token        = nn.Parameter(torch.zeros(1, 1, de_emb_dim))
+        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, de_emb_dim), requires_grad=False)  # fixed sin-cos embedding
+        self.decoder_norm      = nn.LayerNorm(de_emb_dim)
+        self.decoder_pred      = nn.Linear(de_emb_dim, patch_size**2 * img_dim, bias=True)
+        self.blocks            = nn.ModuleList([
+            ViTBlock(de_emb_dim, de_num_heads, mlp_ratio, qkv_bias, dropout=dropout)
+            for _ in range(de_num_layers)])
+        
+        self._init_weights()
+
+    def _init_weights(self):
+        # initialize (and freeze) pos_embed by sin-cos embedding
+        decoder_pos_embed = self.get_posembed(self.decoder_pos_embed.shape[-1], int(self.num_patches**.5))
+        self.decoder_pos_embed.data.copy_(decoder_pos_embed)
+
+        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+        torch.nn.init.normal_(self.mask_token, std=.02)
+
+        # initialize nn.Linear and nn.LayerNorm
+        for m in self.modules():           
+            if isinstance(m, nn.Linear):
+                # we use xavier_uniform following official JAX ViT:
+                torch.nn.init.xavier_uniform_(m.weight)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+    def get_posembed(self, embed_dim, grid_size, temperature=10000):
+        scale = 2 * math.pi
+        grid_h, grid_w = grid_size, grid_size
+        num_pos_feats = embed_dim // 2
+        # get grid
+        y_embed, x_embed = torch.meshgrid([torch.arange(grid_h, dtype=torch.float32),
+                                           torch.arange(grid_w, dtype=torch.float32)])
+        # normalize grid coords
+        y_embed = y_embed / (grid_h + 1e-6) * scale
+        x_embed = x_embed / (grid_w + 1e-6) * scale
+    
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
+        dim_t = temperature ** (2 * dim_t_)
+
+        pos_x = torch.div(x_embed[..., None], dim_t)
+        pos_y = torch.div(y_embed[..., None], dim_t)
+        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
+        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
+
+        # [H, W, C] -> [N, C]
+        pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)
+
+        return pos_embed.unsqueeze(0)
+
+    def forward(self, x_enc, ids_restore):
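+        """
+        x_enc: [B, N_keep, en_emb_dim], features of the visible (unmasked) tokens from the encoder
+        ids_restore: [B, N], indices that restore the original patch order after the random shuffling
+        """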
+        # embed tokens
+        x_enc = self.decoder_embed(x_enc)
+        B, N_nomask, C = x_enc.shape
+
+        # append mask tokens to sequence
+        mask_tokens = self.mask_token.repeat(B, ids_restore.shape[1] - N_nomask, 1)     # [B, N_mask, C], N_mask = N - N_nomask
+        x_all = torch.cat([x_enc, mask_tokens], dim=1)
+        x_all = torch.gather(x_all, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, C))  # unshuffle
+
+        # add pos embed
+        x_all = x_all + self.decoder_pos_embed
+
+        # apply Transformer blocks
+        for block in self.blocks:
+            x_all = block(x_all)
+        x_all = self.decoder_norm(x_all)
+
+        # predict
+        x_out = self.decoder_pred(x_all)
+
+        return x_out
+
+
+# ------------------------ MAE Vision Transformer ------------------------
+class ViTforMaskedAutoEncoder(nn.Module):
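+    """MAE wrapper that couples the encoder and the decoder; in training mode the forward pass also returns the reconstruction loss, computed only on the masked patches."""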
+    def __init__(self,
+                 encoder :MaeEncoder,
+                 decoder :MaeDecoder,
+                 ):
+        super().__init__()
+        self.mae_encoder = encoder
+        self.mae_decoder = decoder
+
+    def patchify(self, imgs, patch_size):
+        """
+        imgs: (B, 3, H, W)
+        x: (B, N, patch_size**2 * 3)
+        """
+        p = patch_size
+        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+        h = w = imgs.shape[2] // p
+        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+        x = torch.einsum('nchpwq->nhwpqc', x)
+        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+
+        return x
+    
+    def unpatchify(self, x, patch_size):
+        """
+        x: (B, N, patch_size**2 *3)
+        imgs: (B, 3, H, W)
+        """
+        p = patch_size
+        h = w = int(x.shape[1]**.5)
+        assert h * w == x.shape[1]
+        
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+        x = torch.einsum('nhwpqc->nchpwq', x)
+        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
+
+        return imgs
+
+    def compute_loss(self, x, output):
+        """
+        x: [B, 3, H, W], input images
+        output["x_pred"]: [B, N, C] reconstructed patches, C = p*p*3
+        output["mask"]: [B, N], 0 is keep, 1 is remove
+        """
+        target = self.patchify(x, self.mae_encoder.patch_size)
+        pred, mask = output["x_pred"], output["mask"]
+        loss = (pred - target) ** 2
+        loss = loss.mean(dim=-1)  # [B, N], mean loss per patch
+        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
+        
+        return loss
+
+    def forward(self, x):
+        imgs = x
+        x, mask, ids_restore = self.mae_encoder(x)
+        x = self.mae_decoder(x, ids_restore)
+        output = {
+            'x_pred': x,
+            'mask': mask
+        }
+
+        if self.training:
+            loss = self.compute_loss(imgs, output)
+            output["loss"] = loss
+
+        return output
+
+
+# ------------------------ Model Functions ------------------------
+def build_vit_mae(model_name="vit_t", img_size=224, patch_size=16, img_dim=3, mask_ratio=0.75):
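+    # model_name selects the encoder scale (vit_t / vit_s / vit_b / vit_l / vit_h); the decoder defined below is shared by all scales.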
+    # ---------------- MAE Encoder ----------------
+    if model_name == "vit_t":
+        encoder = MaeEncoder(img_size=img_size,
+                             patch_size=patch_size,
+                             in_chans=img_dim,
+                             patch_embed_dim=192,
+                             depth=12,
+                             num_heads=3,
+                             mlp_ratio=4.0,
+                             act_layer=nn.GELU,
+                             mask_ratio=mask_ratio,
+                             dropout = 0.1)
+    elif model_name == "vit_s":
+        encoder = MaeEncoder(img_size=img_size,
+                             patch_size=patch_size,
+                             in_chans=img_dim,
+                             patch_embed_dim=384,
+                             depth=12,
+                             num_heads=6,
+                             mlp_ratio=4.0,
+                             act_layer=nn.GELU,
+                             mask_ratio=mask_ratio,
+                             dropout = 0.1)
+    elif model_name == "vit_b":
+        encoder = MaeEncoder(img_size=img_size,
+                             patch_size=patch_size,
+                             in_chans=img_dim,
+                             patch_embed_dim=768,
+                             depth=12,
+                             num_heads=12,
+                             mlp_ratio=4.0,
+                             act_layer=nn.GELU,
+                             mask_ratio=mask_ratio,
+                             dropout = 0.1)
+    elif model_name == "vit_l":
+        encoder = MaeEncoder(img_size=img_size,
+                             patch_size=patch_size,
+                             in_chans=img_dim,
+                             patch_embed_dim=1024,
+                             depth=24,
+                             num_heads=16,
+                             mlp_ratio=4.0,
+                             act_layer=nn.GELU,
+                             mask_ratio=mask_ratio,
+                             dropout = 0.1)
+    elif model_name == "vit_h":
+        encoder = MaeEncoder(img_size=img_size,
+                             patch_size=patch_size,
+                             in_chans=img_dim,
+                             patch_embed_dim=1280,
+                             depth=32,
+                             num_heads=16,
+                             mlp_ratio=4.0,
+                             act_layer=nn.GELU,
+                             mask_ratio=mask_ratio,
+                             dropout = 0.1)
+    else:
+        raise NotImplementedError("Unknown model name: {}".format(model_name))
+    
+    # ---------------- MAE Decoder ----------------
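+    # A lightweight decoder used only during pre-training (512-dim, 8 blocks, 16 heads, matching the MAE paper's default decoder).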
+    decoder = MaeDecoder(img_dim = img_dim,
+                         img_size=img_size,
+                         patch_size=patch_size,
+                         en_emb_dim=encoder.patch_embed_dim,
+                         de_emb_dim=512,
+                         de_num_layers=8,
+                         de_num_heads=16,
+                         qkv_bias=True,
+                         mlp_ratio=4.0,
+                         mask_ratio=mask_ratio,
+                         dropout=0.1,)
+    
+    return ViTforMaskedAutoEncoder(encoder, decoder)
+
+
+if __name__ == '__main__':
+    import torch
+    from thop import profile
+
+    # Prepare an image as the input
+    bs, c, h, w = 2, 3, 224, 224
+    x = torch.randn(bs, c, h, w)
+    patch_size = 16
+
+    # Build model
+    model = build_vit_mae(patch_size=patch_size)
+
+    # Inference
+    outputs = model(x)
+    if "loss" in outputs:
+        print("Loss: ", outputs["loss"].item())
+
+    # Compute FLOPs & Params
+    print('==============================')
+    model.eval()
+    flops, params = profile(model, inputs=(x, ), verbose=False)
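+    # thop counts multiply-accumulate operations, so multiply by 2 to report FLOPs.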
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
+

+ 5 - 0
masked_image_modeling/requirements.txt

@@ -0,0 +1,5 @@
+torch
+torchvision
+opencv-python
+thop
+timm

+ 0 - 0
masked_image_modeling/utils/__init__.py


+ 1 - 3
iclab/utils/lr_scheduler.py → masked_image_modeling/utils/lr_scheduler.py

@@ -10,9 +10,7 @@ class LinearWarmUpLrScheduler(object):
 
     def set_lr(self, optimizer, cur_lr):
         for param_group in optimizer.param_groups:
-            init_lr = param_group['initial_lr']
-            ratio = init_lr / self.base_lr
-            param_group['lr'] = cur_lr * ratio
+            param_group['lr'] = cur_lr
 
     def __call__(self, iter, optimizer):
         # warmup

+ 50 - 110
iclab/utils/misc.py → masked_image_modeling/utils/misc.py

@@ -1,34 +1,13 @@
 import time
+import torch
 import numpy as np
 import random
 import datetime
 from collections import defaultdict, deque
 from pathlib import Path
 
-import torch
-import torch.nn as nn
-import torch.distributed as dist
-
-from .distributed_utils import get_world_size, is_main_process, is_dist_avail_and_initialized
-
 
 # ---------------------- Common functions ----------------------
-def all_reduce_mean(x):
-    world_size = get_world_size()
-    if world_size > 1:
-        x_reduce = torch.tensor(x).cuda()
-        dist.all_reduce(x_reduce)
-        x_reduce /= world_size
-        return x_reduce.item()
-    else:
-        return x
-
-def print_rank_0(msg, rank=None):
-    if rank is not None and rank <= 0:
-        print(msg)
-    elif is_main_process():
-        print(msg)
-
 def setup_seed(seed=42):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -36,10 +15,6 @@ def setup_seed(seed=42):
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
 
-def is_parallel(model):
-    # Returns True if model is of type DP or DDP
-    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
-
 def accuracy(output, target, topk=(1,)):
     """Computes the accuracy over the k top predictions for the specified values of k"""
     with torch.no_grad():
@@ -60,7 +35,6 @@ class SmoothedValue(object):
     """Track a series of values and provide access to smoothed values over a
     window or the global series average.
     """
-
     def __init__(self, window_size=20, fmt=None):
         if fmt is None:
             fmt = "{median:.4f} ({global_avg:.4f})"
@@ -74,19 +48,6 @@ class SmoothedValue(object):
         self.count += n
         self.total += value * n
 
-    def synchronize_between_processes(self):
-        """
-        Warning: does not synchronize the deque!
-        """
-        if not is_dist_avail_and_initialized():
-            return
-        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
-        dist.barrier()
-        dist.all_reduce(t)
-        t = t.tolist()
-        self.count = int(t[0])
-        self.total = t[1]
-
     @property
     def median(self):
         d = torch.tensor(list(self.deque))
@@ -147,10 +108,6 @@ class MetricLogger(object):
             )
         return self.delimiter.join(loss_str)
 
-    def synchronize_between_processes(self):
-        for meter in self.meters.values():
-            meter.synchronize_between_processes()
-
     def add_meter(self, name, meter):
         self.meters[name] = meter
 
@@ -201,52 +158,8 @@ class MetricLogger(object):
             header, total_time_str, total_time / len(iterable)))
 
 
-# ---------------------- Optimize functions ----------------------
-def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
-    if isinstance(parameters, torch.Tensor):
-        parameters = [parameters]
-    parameters = [p for p in parameters if p.grad is not None]
-    norm_type = float(norm_type)
-    if len(parameters) == 0:
-        return torch.tensor(0.)
-    device = parameters[0].grad.device
-    total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device)
-                                        for p in parameters]),
-                            norm_type)
-
-    return total_norm
-
-class NativeScalerWithGradNormCount:
-    state_dict_key = "amp_scaler"
-
-    def __init__(self):
-        self._scaler = torch.cuda.amp.GradScaler()
-
-    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
-        self._scaler.scale(loss).backward()
-        if update_grad:
-            if clip_grad is not None:
-                assert parameters is not None
-                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
-                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
-            else:
-                self._scaler.unscale_(optimizer)
-                norm = get_grad_norm_(parameters)
-            self._scaler.step(optimizer)
-            self._scaler.update()
-        else:
-            norm = None
-        return norm
-
-    def state_dict(self):
-        return self._scaler.state_dict()
-
-    def load_state_dict(self, state_dict):
-        self._scaler.load_state_dict(state_dict)
-
-
 # ---------------------- Model functions ----------------------
-def load_model(args, model_without_ddp, optimizer, lr_scheduler, loss_scaler):
+def load_model(args, model, optimizer, lr_scheduler):
     if args.resume and args.resume.lower() != 'none':
         print("=================== Load checkpoint ===================")
         if args.resume.startswith('https'):
@@ -254,38 +167,65 @@ def load_model(args, model_without_ddp, optimizer, lr_scheduler, loss_scaler):
                 args.resume, map_location='cpu', check_hash=True)
         else:
             checkpoint = torch.load(args.resume, map_location='cpu')
-        model_without_ddp.load_state_dict(checkpoint['model'])
+        model.load_state_dict(checkpoint['model'])
         print("Resume checkpoint %s" % args.resume)
         
         if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
             print('- Load optimizer from the checkpoint. ')
             optimizer.load_state_dict(checkpoint['optimizer'])
             args.start_epoch = checkpoint['epoch'] + 1
-            if 'scaler' in checkpoint:
-                loss_scaler.load_state_dict(checkpoint['scaler'])
 
         if 'lr_scheduler' in checkpoint:
             print('- Load lr scheduler from the checkpoint. ')
             lr_scheduler.load_state_dict(checkpoint.pop("lr_scheduler"))
 
-def save_model(args, epoch, model, model_without_ddp, optimizer, lr_scheduler, loss_scaler, acc1=None):
+def save_model(args, epoch, model, optimizer, lr_scheduler, acc1=None, mae_task=False):
     output_dir = Path(args.output_dir)
     epoch_name = str(epoch)
-    if loss_scaler is not None:
-        if acc1 is not None:
-            checkpoint_paths = [output_dir / ('checkpoint-{}-Acc1-{:.2f}.pth'.format(epoch_name, acc1))]
-        else:
-            checkpoint_paths = [output_dir / ('checkpoint-{}.pth'.format(epoch_name))]
-        for checkpoint_path in checkpoint_paths:
-            to_save = {
-                'model': model_without_ddp.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'lr_scheduler': lr_scheduler.state_dict(),
-                'epoch': epoch,
-                'scaler': loss_scaler.state_dict(),
-                'args': args,
-            }
-            torch.save(to_save, checkpoint_path)
+    if acc1 is not None:
+        checkpoint_paths = [output_dir / ('checkpoint-{}-Acc1-{:.2f}.pth'.format(epoch_name, acc1))]
     else:
-        client_state = {'epoch': epoch}
-        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
+        checkpoint_paths = [output_dir / ('checkpoint-{}.pth'.format(epoch_name))]
+    for checkpoint_path in checkpoint_paths:
+        to_save = {
+            'model': model.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'lr_scheduler': lr_scheduler.state_dict(),
+            'epoch': epoch,
+            'args': args,
+        }
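+        # For MAE pre-training, additionally store the encoder weights so the fine-tuning stage can load them without the decoder.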
+        if mae_task:
+            to_save['encoder'] = model.mae_encoder.state_dict()
+        torch.save(to_save, checkpoint_path)
+
+
+# ---------------------- Patch operations ----------------------
+def patchify(imgs, patch_size):
+    """
+    imgs: (B, 3, H, W)
+    x: (B, N, patch_size**2 * 3)
+    """
+    p = patch_size
+    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+    h = w = imgs.shape[2] // p
+    x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+    x = torch.einsum('nchpwq->nhwpqc', x)
+    x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+
+    return x
+
+def unpatchify(x, patch_size):
+    """
+    x: (B, N, patch_size**2 *3)
+    imgs: (B, 3, H, W)
+    """
+    p = patch_size
+    h = w = int(x.shape[1]**.5)
+    assert h * w == x.shape[1]
+    
+    x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+    x = torch.einsum('nhwpqc->nchpwq', x)
+    imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
+
+    return imgs

+ 25 - 0
masked_image_modeling/utils/optimizer.py

@@ -0,0 +1,25 @@
+import torch
+
+
+def build_optimizer(args, model):
+    # Scale the learning rate linearly with the batch size: lr = base_lr * batch_size / 256
+    if args.optimizer == "adamw":
+        args.base_lr = args.base_lr / 256 * args.batch_size
+        optimizer = torch.optim.AdamW(model.parameters(),
+                                      lr=args.base_lr,
+                                      weight_decay=args.weight_decay)
+    elif args.optimizer == "sgd":
+        args.base_lr = args.base_lr / 256 * args.batch_size
+        optimizer = torch.optim.SGD(model.parameters(),
+                                    lr=args.base_lr,
+                                    momentum=0.9,
+                                    weight_decay=args.weight_decay)
+    else:
+        raise NotImplementedError("Unknown optimizer: {}".format(args.optimizer))
+
+    print("=================== Optimizer information ===================")
+    print("Optimizer: ", args.optimizer)
+    print('- base lr: ', args.base_lr)
+    print('- min  lr: ', args.min_lr)
+
+    return optimizer

+ 0 - 3
yolo/config/__init__.py

@@ -6,7 +6,6 @@ from .yolov5_config     import build_yolov5_config
 from .yolov5_af_config  import build_yolov5af_config
 from .yolov6_config     import build_yolov6_config
 from .yolov8_config     import build_yolov8_config
-from .yolov8_e2e_config import build_yolov8_e2e_config
 from .gelan_config      import build_gelan_config
 from .rtdetr_config     import build_rtdetr_config
 
@@ -27,8 +26,6 @@ def build_config(args):
         cfg = build_yolov5_config(args)
     elif 'yolov6' in args.model:
         cfg = build_yolov6_config(args)
-    elif 'yolov8_e2e' in args.model:
-        cfg = build_yolov8_e2e_config(args)
     elif 'yolov8' in args.model:
         cfg = build_yolov8_config(args)
     elif 'gelan' in args.model:

+ 0 - 196
yolo/config/yolov8_e2e_config.py

@@ -1,196 +0,0 @@
-# yolo Config
-
-
-def build_yolov8_e2e_config(args):
-    if   args.model == 'yolov8_e2e_n':
-        return Yolov8E2E_N_Config()
-    elif args.model == 'yolov8_e2e_s':
-        return Yolov8E2E_S_Config()
-    elif args.model == 'yolov8_e2e_m':
-        return Yolov8E2E_M_Config()
-    elif args.model == 'yolov8_e2e_l':
-        return Yolov8E2E_L_Config()
-    elif args.model == 'yolov8_e2e_x':
-        return Yolov8E2E_X_Config()
-    else:
-        raise NotImplementedError("No config for model: {}".format(args.model))
-    
-# YOLOv8-E2E Base config
-class Yolov8E2EBaseConfig(object):
-    def __init__(self) -> None:
-        # ---------------- Model config ----------------
-        self.width    = 1.0
-        self.depth    = 1.0
-        self.ratio    = 1.0
-        self.reg_max  = 16
-        self.out_stride = [8, 16, 32]
-        self.max_stride = 32
-        self.num_levels = 3
-        self.scale      = "b"
-        ## Backbone
-        self.bk_act   = 'silu'
-        self.bk_norm  = 'BN'
-        self.bk_depthwise = False
-        self.use_pretrained = True
-        ## Neck
-        self.neck_act       = 'silu'
-        self.neck_norm      = 'BN'
-        self.neck_depthwise = False
-        self.neck_expand_ratio = 0.5
-        self.spp_pooling_size  = 5
-        ## FPN
-        self.fpn_act  = 'silu'
-        self.fpn_norm = 'BN'
-        self.fpn_depthwise = False
-        ## Head
-        self.head_act  = 'silu'
-        self.head_norm = 'BN'
-        self.head_depthwise = False
-        self.num_cls_head   = 2
-        self.num_reg_head   = 2
-
-        # ---------------- Post-process config ----------------
-        ## Post process
-        self.val_topk = 100
-        self.val_conf_thresh = 0.001
-        self.test_topk = 100
-        self.test_conf_thresh = 0.2
-
-        # ---------------- Assignment config ----------------
-        ## Matcher
-        self.tal_topk_candidates = 10
-        self.tal_alpha = 0.5
-        self.tal_beta  = 6.0
-        ## Loss weight
-        self.loss_cls = 0.5
-        self.loss_box = 7.5
-        self.loss_dfl = 1.5
-
-        # ---------------- ModelEMA config ----------------
-        self.use_ema = True
-        self.ema_decay = 0.9998
-        self.ema_tau   = 2000
-
-        # ---------------- Optimizer config ----------------
-        self.trainer      = 'yolo'
-        self.optimizer    = 'adamw'
-        self.base_lr      = 0.001     # base_lr = per_image_lr * batch_size
-        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
-        self.batch_size_base = 64
-        self.momentum     = 0.9
-        self.weight_decay = 0.05
-        self.clip_max_norm   = 35.0
-        self.warmup_bias_lr  = 0.1
-        self.warmup_momentum = 0.8
-
-        # ---------------- Lr Scheduler config ----------------
-        self.warmup_epoch = 3
-        self.lr_scheduler = "cosine"
-        self.max_epoch    = 500
-        self.eval_epoch   = 10
-        self.no_aug_epoch = 20
-
-        # ---------------- Data process config ----------------
-        self.aug_type = 'yolo'
-        self.box_format = 'xyxy'
-        self.normalize_coords = False
-        self.mosaic_prob = 0.0
-        self.mixup_prob  = 0.0
-        self.copy_paste  = 0.0           # approximated by the YOLOX's mixup
-        self.multi_scale = [0.5, 1.5]   # multi scale: [img_size * 0.5, img_size * 1.5]
-        ## Pixel mean & std
-        self.pixel_mean = [0., 0., 0.]
-        self.pixel_std  = [255., 255., 255.]
-        ## Transforms
-        self.train_img_size = 640
-        self.test_img_size  = 640
-        self.affine_params = {
-            'degrees': 0.0,
-            'translate': 0.2,
-            'scale': [0.1, 2.0],
-            'shear': 0.0,
-            'perspective': 0.0,
-            'hsv_h': 0.015,
-            'hsv_s': 0.7,
-            'hsv_v': 0.4,
-        }
-
-    def print_config(self):
-        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
-        for k, v in config_dict.items():
-            print("{} : {}".format(k, v))
-
-# YOLOv8-E2E N
-class Yolov8E2E_N_Config(Yolov8E2EBaseConfig):
-    def __init__(self) -> None:
-        super().__init__()
-        # ---------------- Model config ----------------
-        self.width = 0.25
-        self.depth = 0.34
-        self.ratio = 2.0
-        self.scale = "n"
-
-        # ---------------- Data process config ----------------
-        self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.0
-        self.copy_paste  = 0.5
-
-# YOLOv8-S
-class Yolov8E2E_S_Config(Yolov8E2EBaseConfig):
-    def __init__(self) -> None:
-        super().__init__()
-        # ---------------- Model config ----------------
-        self.width = 0.50
-        self.depth = 0.34
-        self.ratio = 2.0
-        self.scale = "s"
-
-        # ---------------- Data process config ----------------
-        self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.0
-        self.copy_paste  = 0.5
-
-# YOLOv8-M
-class Yolov8E2E_M_Config(Yolov8E2EBaseConfig):
-    def __init__(self) -> None:
-        super().__init__()
-        # ---------------- Model config ----------------
-        self.width = 0.75
-        self.depth = 0.67
-        self.ratio = 1.5
-        self.scale = "m"
-
-        # ---------------- Data process config ----------------
-        self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.1
-        self.copy_paste  = 0.5
-
-# YOLOv8-L
-class Yolov8E2E_L_Config(Yolov8E2EBaseConfig):
-    def __init__(self) -> None:
-        super().__init__()
-        # ---------------- Model config ----------------
-        self.width = 1.0
-        self.depth = 1.0
-        self.ratio = 1.0
-        self.scale = "l"
-
-        # ---------------- Data process config ----------------
-        self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.1
-        self.copy_paste  = 0.5
-
-# YOLOv8-X
-class Yolov8E2E_X_Config(Yolov8E2EBaseConfig):
-    def __init__(self) -> None:
-        super().__init__()
-        # ---------------- Model config ----------------
-        self.width = 1.25
-        self.depth = 1.0
-        self.ratio = 1.0
-        self.scale = "x"
-
-        # ---------------- Data process config ----------------
-        self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.1
-        self.copy_paste  = 0.5

+ 0 - 4
yolo/models/__init__.py

@@ -9,7 +9,6 @@ from .yolov5.build     import build_yolov5
 from .yolov5_af.build  import build_yolov5af
 from .yolov6.build     import build_yolov6
 from .yolov8.build     import build_yolov8
-from .yolov8_e2e.build import build_yolov8_e2e
 from .gelan.build      import build_gelan
 from .rtdetr.build     import build_rtdetr
 
@@ -36,9 +35,6 @@ def build_model(args, cfg, is_val=False):
     elif 'yolov6' in args.model:
         model, criterion = build_yolov6(cfg, is_val)
     ## YOLOv8
-    elif 'yolov8_e2e' in args.model:
-        model, criterion = build_yolov8_e2e(cfg, is_val)
-    ## YOLOv8
     elif 'yolov8' in args.model:
         model, criterion = build_yolov8(cfg, is_val)
     ## GElan

+ 0 - 60
yolo/models/yolov8_e2e/README.md

@@ -1,60 +0,0 @@
-# End-to-End YOLOv8:
-
-Inspired by YOLOv10, I deploy two parallel detection heads: one using a one-to-many assigner (o2m head) and the other using a one-to-one assigner (o2o head). To avoid conflicts between the gradients from the o2o head and the o2m head, we stop the gradients from the o2o head before they reach the backbone and neck, so that only the gradients from the o2m head update the backbone and neck. This operation is consistent with the practice of YOLOv10. For evaluation, we remove the o2m head and use only the o2o head, without NMS.
-
-However, I have no GPU to train YOLOv8-E2E.
-
-- VOC
-
-|     Model   | Batch | Scale | AP<sup>val<br>0.5 | Weight |  Logs  |
-|-------------|-------|-------|-------------------|--------|--------|
-| YOLOv8-E2E-S    | 1xb16 |  640  |               |  |  |
-
-- COCO
-
-|    Model    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |  Logs  |
-|-------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|--------|
-| YOLOv8-E2E-S    | 1xb16 |  640  |                    |               |   26.9            |   8.9             |  |  |
-
-
-
-## Train YOLOv8-E2E
-### Single GPU
-Taking training YOLOv8-E2E-S on COCO as the example,
-```Shell
-python train.py --cuda -d coco --root path/to/coco -m yolov8_e2e_s -bs 16 --fp16 
-```
-
-### Multi GPU
-Taking training YOLOv8-E2E-S on COCO as the example,
-```Shell
-python -m torch.distributed.run --nproc_per_node=8 train.py --cuda --distributed -d coco --root path/to/coco -m yolov8_e2e_s -bs 256 --fp16 
-```
-
-## Test YOLOv8
-Taking testing YOLOv8-E2E-S on COCO-val as the example,
-```Shell
-python test.py --cuda -d coco --root path/to/coco -m yolov8_e2e_s --weight path/to/yolov8.pth --show 
-```
-
-## Evaluate YOLOv8
-Taking evaluating YOLOv8-E2E-S on COCO-val as the example,
-```Shell
-python eval.py --cuda -d coco --root path/to/coco -m yolov8_e2e_s --weight path/to/yolov8.pth 
-```
-
-## Demo
-### Detect with Image
-```Shell
-python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolov8_e2e_s --weight path/to/weight --show
-```
-
-### Detect with Video
-```Shell
-python demo.py --mode video --path_to_vid path/to/video --cuda -m yolov8_e2e_s --weight path/to/weight --show --gif
-```
-
-### Detect with Camera
-```Shell
-python demo.py --mode camera --cuda -m yolov8_e2e_s --weight path/to/weight --show --gif
-```

+ 0 - 24
yolo/models/yolov8_e2e/build.py

@@ -1,24 +0,0 @@
-import torch.nn as nn
-
-from .loss import SetCriterion
-from .yolov8_e2e import Yolov8E2E
-
-
-# build object detector
-def build_yolov8_e2e(cfg, is_val=False):
-    # -------------- Build YOLO --------------
-    model = Yolov8E2E(cfg, is_val)
-
-    # -------------- Initialize YOLO --------------
-    for m in model.modules():
-        if isinstance(m, nn.BatchNorm2d):
-            m.eps = 1e-3
-            m.momentum = 0.03    
-            
-    # -------------- Build criterion --------------
-    criterion = None
-    if is_val:
-        # build criterion for training
-        criterion = SetCriterion(cfg)
-        
-    return model, criterion

+ 0 - 204
yolo/models/yolov8_e2e/loss.py

@@ -1,204 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from utils.box_ops import bbox2dist, bbox_iou
-from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
-
-from .matcher import TaskAlignedAssigner
-
-
-class SetCriterion(object):
-    def __init__(self, cfg):
-        # --------------- Basic parameters ---------------
-        self.cfg = cfg
-        self.reg_max = cfg.reg_max
-        self.num_classes = cfg.num_classes
-        # --------------- Loss config ---------------
-        self.loss_cls_weight = cfg.loss_cls
-        self.loss_box_weight = cfg.loss_box
-        self.loss_dfl_weight = cfg.loss_dfl
-        # --------------- Matcher config ---------------
-        self.matcher = TaskAlignedAssigner(num_classes     = cfg.num_classes,
-                                           topk_candidates = cfg.tal_topk_candidates,
-                                           alpha           = cfg.tal_alpha,
-                                           beta            = cfg.tal_beta
-                                           )
-
-    def loss_classes(self, pred_cls, gt_score):
-        # compute bce loss
-        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
-
-        return loss_cls
-    
-    def loss_bboxes(self, pred_box, gt_box, bbox_weight):
-        # regression loss
-        ious = bbox_iou(pred_box, gt_box, xywh=False, CIoU=True)
-        loss_box = (1.0 - ious.squeeze(-1)) * bbox_weight
-
-        return loss_box
-    
-    def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
-        # rescale coords by stride
-        gt_box_s = gt_box / stride
-        anchor_s = anchor / stride
-
-        # compute deltas
-        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.reg_max - 1)
-
-        gt_left = gt_ltrb_s.to(torch.long)
-        gt_right = gt_left + 1
-
-        weight_left = gt_right.to(torch.float) - gt_ltrb_s
-        weight_right = 1 - weight_left
-
-        # loss left
-        loss_left = F.cross_entropy(
-            pred_reg.view(-1, self.reg_max),
-            gt_left.view(-1),
-            reduction='none').view(gt_left.shape) * weight_left
-        # loss right
-        loss_right = F.cross_entropy(
-            pred_reg.view(-1, self.reg_max),
-            gt_right.view(-1),
-            reduction='none').view(gt_left.shape) * weight_right
-
-        loss_dfl = (loss_left + loss_right).mean(-1)
-        
-        if bbox_weight is not None:
-            loss_dfl *= bbox_weight
-
-        return loss_dfl
-
-    def compute_loss(self, outputs, targets):
-        """
-            outputs['pred_cls']: List(Tensor) [B, M, C]
-            outputs['pred_reg']: List(Tensor) [B, M, 4*(reg_max+1)]
-            outputs['pred_box']: List(Tensor) [B, M, 4]
-            outputs['anchors']: List(Tensor) [M, 2]
-            outputs['strides']: List(Int) [8, 16, 32] output stride
-            outputs['stride_tensor']: List(Tensor) [M, 1]
-            targets: (List) [dict{'boxes': [...], 
-                                 'labels': [...], 
-                                 'orig_size': ...}, ...]
-        """
-        # preds: [B, M, C]
-        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
-        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
-        box_preds = torch.cat(outputs['pred_box'], dim=1)
-        bs, num_anchors = cls_preds.shape[:2]
-        device = cls_preds.device
-        anchors = torch.cat(outputs['anchors'], dim=0)
-        
-        # --------------- label assignment ---------------
-        gt_score_targets = []
-        gt_bbox_targets = []
-        fg_masks = []
-        for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
-            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
-
-            if self.cfg.normalize_coords:
-                img_h, img_w = outputs['image_size']
-                tgt_boxs[..., [0, 2]] *= img_w
-                tgt_boxs[..., [1, 3]] *= img_h
-            
-            if self.cfg.box_format == 'xywh':
-                tgt_boxs_x1y1 = tgt_boxs[..., :2] - 0.5 * tgt_boxs[..., 2:]
-                tgt_boxs_x2y2 = tgt_boxs[..., :2] + 0.5 * tgt_boxs[..., 2:]
-                tgt_boxs = torch.cat([tgt_boxs_x1y1, tgt_boxs_x2y2], dim=-1)
-
-            # check target
-            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
-                # There is no valid gt
-                fg_mask  = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
-                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
-                gt_box   = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
-            else:
-                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
-                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
-                (
-                    _,
-                    gt_box,     # [1, M, 4]
-                    gt_score,   # [1, M, C]
-                    fg_mask,    # [1, M,]
-                    _
-                ) = self.matcher(
-                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
-                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
-                    anc_points = anchors,
-                    gt_labels = tgt_labels,
-                    gt_bboxes = tgt_boxs
-                    )
-            gt_score_targets.append(gt_score)
-            gt_bbox_targets.append(gt_box)
-            fg_masks.append(fg_mask)
-
-        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
-        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
-        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
-        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
-        num_fgs = gt_score_targets.sum()
-        
-        # Average loss normalizer across all the GPUs
-        if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_fgs)
-        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
-
-        # ------------------ Classification loss ------------------
-        cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, gt_score_targets)
-        loss_cls = loss_cls.sum() / num_fgs
-
-        # ------------------ Regression loss ------------------
-        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
-        box_targets_pos = gt_bbox_targets.view(-1, 4)[fg_masks]
-        bbox_weight = gt_score_targets[fg_masks].sum(-1)
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
-        loss_box = loss_box.sum() / num_fgs
-
-        # ------------------ Distribution focal loss  ------------------
-        ## process anchors
-        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
-        ## process stride tensors
-        strides = torch.cat(outputs['stride_tensor'], dim=0)
-        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
-        ## fg preds
-        reg_preds_pos = reg_preds.view(-1, 4*self.reg_max)[fg_masks]
-        anchors_pos = anchors[fg_masks]
-        strides_pos = strides[fg_masks]
-        ## compute dfl
-        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
-        loss_dfl = loss_dfl.sum() / num_fgs
-
-        # total loss
-        losses = loss_cls * self.loss_cls_weight + \
-                 loss_box * self.loss_box_weight + \
-                 loss_dfl * self.loss_dfl_weight
-        loss_dict = dict(
-                loss_cls = loss_cls,
-                loss_box = loss_box,
-                loss_dfl = loss_dfl,
-                losses = losses
-        )
-
-        return loss_dict
-    
-    def __call__(self, outputs, targets):
-        self.matcher.topk_candidates = self.cfg.tal_topk_candidates
-        o2m_loss_dict = self.compute_loss(outputs["outputs_o2m"], targets)
-
-        self.matcher.topk_candidates = 1
-        o2o_loss_dict = self.compute_loss(outputs["outputs_o2o"], targets)
-
-        loss_dict = {}
-        loss_dict["losses"] = o2o_loss_dict["losses"] + o2m_loss_dict["losses"]
-        for k in o2m_loss_dict:
-            loss_dict['o2m_' + k] = o2m_loss_dict[k]
-        for k in o2o_loss_dict:
-            loss_dict['o2o_' + k] = o2o_loss_dict[k]
-
-        return loss_dict
-
-if __name__ == "__main__":
-    pass

+ 0 - 202
yolo/models/yolov8_e2e/matcher.py

@@ -1,202 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from utils.box_ops import bbox_iou
-
-
-# -------------------------- Task Aligned Assigner --------------------------
-class TaskAlignedAssigner(nn.Module):
-    """
-        This code referenced to https://github.com/ultralytics/ultralytics
-    """
-    def __init__(self,
-                 num_classes     = 80,
-                 topk_candidates = 10,
-                 alpha           = 0.5,
-                 beta            = 6.0, 
-                 eps             = 1e-9):
-        super(TaskAlignedAssigner, self).__init__()
-        self.topk_candidates = topk_candidates
-        self.num_classes = num_classes
-        self.bg_idx = num_classes
-        self.alpha = alpha
-        self.beta = beta
-        self.eps = eps
-
-    @torch.no_grad()
-    def forward(self,
-                pd_scores,
-                pd_bboxes,
-                anc_points,
-                gt_labels,
-                gt_bboxes):
-        self.bs = pd_scores.size(0)
-        self.n_max_boxes = gt_bboxes.size(1)
-
-        mask_pos, align_metric, overlaps = self.get_pos_mask(
-            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
-
-        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
-            mask_pos, overlaps, self.n_max_boxes)
-
-        # Assigned target
-        target_labels, target_bboxes, target_scores = self.get_targets(
-            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
-
-        # normalize
-        align_metric *= mask_pos
-        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
-        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
-        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
-        target_scores = target_scores * norm_align_metric
-
-        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
-
-    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
-        # get in_gts mask, (b, max_num_obj, h*w)
-        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
-        # get anchor_align metric, (b, max_num_obj, h*w)
-        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts)
-        # get topk_metric mask, (b, max_num_obj, h*w)
-        mask_topk = self.select_topk_candidates(align_metric)
-        # merge all mask to a final mask, (b, max_num_obj, h*w)
-        mask_pos = mask_topk * mask_in_gts
-
-        return mask_pos, align_metric, overlaps
-
-    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts):
-        """Compute alignment metric given predicted and ground truth bounding boxes."""
-        na = pd_bboxes.shape[-2]
-        mask_in_gts = mask_in_gts.bool()  # b, max_num_obj, h*w
-        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
-        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
-
-        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
-        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
-        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
-        # Get the scores of each grid for each gt cls
-        bbox_scores[mask_in_gts] = pd_scores[ind[0], :, ind[1]][mask_in_gts]  # b, max_num_obj, h*w
-
-        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
-        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_in_gts]
-        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_in_gts]
-        overlaps[mask_in_gts] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
-
-        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
-        return align_metric, overlaps
-
-    def select_topk_candidates(self, metrics, largest=True):
-        """
-        Args:
-            metrics: (b, max_num_obj, h*w).
-            topk_mask: (b, max_num_obj, topk) or None
-        """
-        # (b, max_num_obj, topk)
-        topk_metrics, topk_idxs = torch.topk(metrics, self.topk_candidates, dim=-1, largest=largest)
-        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
-        # (b, max_num_obj, topk)
-        topk_idxs.masked_fill_(~topk_mask, 0)
-
-        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
-        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
-        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
-        for k in range(self.topk_candidates):
-            # Expand topk_idxs for each value of k and add 1 at the specified positions
-            count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
-        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
-        # Filter invalid bboxes
-        count_tensor.masked_fill_(count_tensor > 1, 0)
-
-        return count_tensor.to(metrics.dtype)
-
-    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
-        # Assigned target labels, (b, 1)
-        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
-        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
-        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
-
-        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
-        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
-
-        # Assigned target scores
-        target_labels.clamp_(0)
-
-        # 10x faster than F.one_hot()
-        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
-                                    dtype=torch.int64,
-                                    device=target_labels.device)  # (b, h*w, 80)
-        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
-
-        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
-        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
-
-        return target_labels, target_bboxes, target_scores
-    
-
-# -------------------------- Basic Functions --------------------------
-def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
-    """select the positive anchors's center in gt
-    Args:
-        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
-        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
-    Return:
-        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    n_anchors = xy_centers.size(0)
-    bs, n_max_boxes, _ = gt_bboxes.size()
-    _gt_bboxes = gt_bboxes.reshape([-1, 4])
-    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
-    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
-    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
-    b_lt = xy_centers - gt_bboxes_lt
-    b_rb = gt_bboxes_rb - xy_centers
-    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
-    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
-    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
-
-def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
-    """if an anchor box is assigned to multiple gts,
-        the one with the highest iou will be selected.
-    Args:
-        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    Return:
-        target_gt_idx (Tensor): shape(bs, num_total_anchors)
-        fg_mask (Tensor): shape(bs, num_total_anchors)
-        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    fg_mask = mask_pos.sum(-2)
-    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
-        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
-        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
-
-        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
-        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
-
-        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
-        fg_mask = mask_pos.sum(-2)
-    # Find each grid serve which gt(index)
-    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
-
-    return target_gt_idx, fg_mask, mask_pos
-
-def iou_calculator(box1, box2, eps=1e-9):
-    """Calculate iou for batch
-    Args:
-        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
-        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
-    Return:
-        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
-    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
-    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
-    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
-    x1y1 = torch.maximum(px1y1, gx1y1)
-    x2y2 = torch.minimum(px2y2, gx2y2)
-    overlap = (x2y2 - x1y1).clip(0).prod(-1)
-    area1 = (px2y2 - px1y1).clip(0).prod(-1)
-    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
-    union = area1 + area2 - overlap + eps
-
-    return overlap / union

+ 0 - 181
yolo/models/yolov8_e2e/yolov8_backbone.py

@@ -1,181 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .yolov8_basic import BasicConv, ELANLayer
-except:
-    from  yolov8_basic import BasicConv, ELANLayer
-
-# IN1K pretrained weight
-pretrained_urls = {
-    'n': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_n_in1k_62.1.pth",
-    's': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_s_in1k_71.3.pth",
-    'm': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_m_in1k_75.7.pth",
-    'l': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/elandarknet_l_in1k_77.3.pth",
-    'x': None,
-}
-
-# ---------------------------- Basic functions ----------------------------
-class Yolov8Backbone(nn.Module):
-    def __init__(self, cfg):
-        super(Yolov8Backbone, self).__init__()
-        # ------------------ Basic setting ------------------
-        self.model_scale = cfg.scale
-        self.feat_dims = [round(64  * cfg.width),
-                          round(128 * cfg.width),
-                          round(256 * cfg.width),
-                          round(512 * cfg.width),
-                          round(512 * cfg.width * cfg.ratio)]
-        
-        # ------------------ Network setting ------------------
-        ## P1/2
-        self.layer_1 = BasicConv(3, self.feat_dims[0],
-                                 kernel_size=3, padding=1, stride=2,
-                                 act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
-        # P2/4
-        self.layer_2 = nn.Sequential(
-            BasicConv(self.feat_dims[0], self.feat_dims[1],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
-            ELANLayer(in_dim     = self.feat_dims[1],
-                      out_dim    = self.feat_dims[1],
-                      num_blocks = round(3*cfg.depth),
-                      expansion  = 0.5,
-                      shortcut   = True,
-                      act_type   = cfg.bk_act,
-                      norm_type  = cfg.bk_norm,
-                      depthwise  = cfg.bk_depthwise)
-        )
-        # P3/8
-        self.layer_3 = nn.Sequential(
-            BasicConv(self.feat_dims[1], self.feat_dims[2],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
-            ELANLayer(in_dim     = self.feat_dims[2],
-                      out_dim    = self.feat_dims[2],
-                      num_blocks = round(6*cfg.depth),
-                      expansion  = 0.5,
-                      shortcut   = True,
-                      act_type   = cfg.bk_act,
-                      norm_type  = cfg.bk_norm,
-                      depthwise  = cfg.bk_depthwise)
-        )
-        # P4/16
-        self.layer_4 = nn.Sequential(
-            BasicConv(self.feat_dims[2], self.feat_dims[3],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
-            ELANLayer(in_dim     = self.feat_dims[3],
-                      out_dim    = self.feat_dims[3],
-                      num_blocks = round(6*cfg.depth),
-                      expansion  = 0.5,
-                      shortcut   = True,
-                      act_type   = cfg.bk_act,
-                      norm_type  = cfg.bk_norm,
-                      depthwise  = cfg.bk_depthwise)
-        )
-        # P5/32
-        self.layer_5 = nn.Sequential(
-            BasicConv(self.feat_dims[3], self.feat_dims[4],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
-            ELANLayer(in_dim     = self.feat_dims[4],
-                      out_dim    = self.feat_dims[4],
-                      num_blocks = round(3*cfg.depth),
-                      expansion  = 0.5,
-                      shortcut   = True,
-                      act_type   = cfg.bk_act,
-                      norm_type  = cfg.bk_norm,
-                      depthwise  = cfg.bk_depthwise)
-        )
-
-        # Initialize all layers
-        self.init_weights()
-        
-        # Load imagenet pretrained weight
-        if cfg.use_pretrained:
-            self.load_pretrained()
-        
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def load_pretrained(self):
-        url = pretrained_urls[self.model_scale]
-        if url is not None:
-            print('Loading backbone pretrained weight from : {}'.format(url))
-            # checkpoint state dict
-            checkpoint = torch.hub.load_state_dict_from_url(
-                url=url, map_location="cpu", check_hash=True)
-            checkpoint_state_dict = checkpoint.pop("model")
-            # model state dict
-            model_state_dict = self.state_dict()
-            # check
-            for k in list(checkpoint_state_dict.keys()):
-                if k in model_state_dict:
-                    shape_model = tuple(model_state_dict[k].shape)
-                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
-                    if shape_model != shape_checkpoint:
-                        checkpoint_state_dict.pop(k)
-                else:
-                    checkpoint_state_dict.pop(k)
-                    print('Unused key: ', k)
-            # load the weight
-            self.load_state_dict(checkpoint_state_dict)
-        else:
-            print('No pretrained weight for model scale: {}.'.format(self.model_scale))
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-        outputs = [c3, c4, c5]
-
-        return outputs
-
-
-# ---------------------------- Functions ----------------------------
-## build Yolo's Backbone
-def build_backbone(cfg): 
-    # model
-    backbone = Yolov8Backbone(cfg)
-        
-    return backbone
-
-
-if __name__ == '__main__':
-    import time
-    from thop import profile
-    class BaseConfig(object):
-        def __init__(self) -> None:
-            self.bk_act = 'silu'
-            self.bk_norm = 'BN'
-            self.bk_depthwise = False
-            self.use_pretrained = True
-            self.width = 0.50
-            self.depth = 0.34
-            self.ratio = 2.0
-            self.scale = "s"
-
-    cfg = BaseConfig()
-    model = build_backbone(cfg)
-    x = torch.randn(1, 3, 640, 640)
-    t0 = time.time()
-    outputs = model(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    for out in outputs:
-        print(out.shape)
-
-    x = torch.randn(1, 3, 640, 640)
-    print('==============================')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))

+ 0 - 172
yolo/models/yolov8_e2e/yolov8_basic.py

@@ -1,172 +0,0 @@
-import torch
-import torch.nn as nn
-from typing import List
-
-
-# --------------------- Basic modules ---------------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-    elif act_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-        
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
-class BasicConv(nn.Module):
-    def __init__(self, 
-                 in_dim,                   # in channels
-                 out_dim,                  # out channels 
-                 kernel_size=1,            # kernel size 
-                 padding=0,                # padding
-                 stride=1,                 # padding
-                 dilation=1,               # dilation
-                 act_type  :str = 'lrelu', # activation
-                 norm_type :str = 'BN',    # normalization
-                 depthwise :bool = False
-                ):
-        super(BasicConv, self).__init__()
-        self.depthwise = depthwise
-        use_bias = False if norm_type is not None else True
-        if not depthwise:
-            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1, bias=use_bias)
-            self.norm = get_norm(norm_type, out_dim)
-        else:
-            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias)
-            self.norm1 = get_norm(norm_type, in_dim)
-            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
-            self.norm2 = get_norm(norm_type, out_dim)
-        self.act  = get_activation(act_type)
-
-    def forward(self, x):
-        if not self.depthwise:
-            return self.act(self.norm(self.conv(x)))
-        else:
-            # Depthwise conv
-            x = self.norm1(self.conv1(x))
-            # Pointwise conv
-            x = self.act(self.norm2(self.conv2(x)))
-            return x
-
-
-# --------------------- Yolov8 modules ---------------------
-class YoloBottleneck(nn.Module):
-    def __init__(self,
-                 in_dim      :int,
-                 out_dim     :int,
-                 kernel_size :List  = [1, 3],
-                 expansion   :float = 0.5,
-                 shortcut    :bool  = False,
-                 act_type    :str   = 'silu',
-                 norm_type   :str   = 'BN',
-                 depthwise   :bool  = False,
-                 ) -> None:
-        super(YoloBottleneck, self).__init__()
-        inter_dim = int(out_dim * expansion)
-        # ----------------- Network setting -----------------
-        self.conv_layer1 = BasicConv(in_dim, inter_dim,
-                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
-                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.conv_layer2 = BasicConv(inter_dim, out_dim,
-                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
-                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.shortcut = shortcut and in_dim == out_dim
-
-    def forward(self, x):
-        h = self.conv_layer2(self.conv_layer1(x))
-
-        return x + h if self.shortcut else h
-
-class CSPLayer(nn.Module):
-    # CSP Bottleneck with 3 convolutions
-    def __init__(self,
-                 in_dim      :int,
-                 out_dim     :int,
-                 num_blocks  :int   = 1,
-                 kernel_size :List = [3, 3],
-                 expansion   :float = 0.5,
-                 shortcut    :bool  = True,
-                 act_type    :str   = 'silu',
-                 norm_type   :str   = 'BN',
-                 depthwise   :bool  = False,
-                 ) -> None:
-        super().__init__()
-        inter_dim = round(out_dim * expansion)
-        self.input_proj_1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.input_proj_2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.output_proj  = BasicConv(2 * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.module       = nn.Sequential(*[YoloBottleneck(inter_dim,
-                                                           inter_dim,
-                                                           kernel_size,
-                                                           expansion   = 1.0,
-                                                           shortcut    = shortcut,
-                                                           act_type    = act_type,
-                                                           norm_type   = norm_type,
-                                                           depthwise   = depthwise,
-                                                           ) for _ in range(num_blocks)])
-
-    def forward(self, x):
-        x1 = self.input_proj_1(x)
-        x2 = self.input_proj_2(x)
-        x2 = self.module(x2)
-        out = self.output_proj(torch.cat([x1, x2], dim=1))
-
-        return out
-
-class ELANLayer(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 expansion  :float = 0.5,
-                 num_blocks :int   = 1,
-                 shortcut   :bool  = False,
-                 act_type   :str   = 'silu',
-                 norm_type  :str   = 'BN',
-                 depthwise  :bool  = False,
-                 ) -> None:
-        super(ELANLayer, self).__init__()
-        inter_dim = round(out_dim * expansion)
-        self.input_proj  = BasicConv(in_dim, inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.output_proj = BasicConv((2 + num_blocks) * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.module      = nn.ModuleList([YoloBottleneck(inter_dim,
-                                                         inter_dim,
-                                                         kernel_size = [3, 3],
-                                                         expansion   = 1.0,
-                                                         shortcut    = shortcut,
-                                                         act_type    = act_type,
-                                                         norm_type   = norm_type,
-                                                         depthwise   = depthwise)
-                                                         for _ in range(num_blocks)])
-
-    def forward(self, x):
-        # Input proj
-        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
-        out = list([x1, x2])
-
-        # Bottlenecks
-        out.extend(m(out[-1]) for m in self.module)
-
-        # Output proj
-        out = self.output_proj(torch.cat(out, dim=1))
-
-        return out

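Note on the removed ELANLayer above: its forward pass splits the 2 * inter_dim input projection into two branches and appends one more inter_dim branch per bottleneck, so the tensor fed to output_proj has (2 + num_blocks) * inter_dim channels. The snippet below is a minimal, self-contained sketch of that channel flow; the relu stand-in for YoloBottleneck is purely illustrative.

import torch

inter_dim, num_blocks = 64, 3
x = torch.randn(1, 2 * inter_dim, 40, 40)            # output of input_proj
x1, x2 = torch.chunk(x, 2, dim=1)                    # two inter_dim branches
out = [x1, x2]
for _ in range(num_blocks):
    out.append(torch.relu(out[-1]))                  # stand-in for one YoloBottleneck
out = torch.cat(out, dim=1)
assert out.shape[1] == (2 + num_blocks) * inter_dim  # matches output_proj's in_dim
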
+ 0 - 179
yolo/models/yolov8_e2e/yolov8_e2e.py

@@ -1,179 +0,0 @@
-# --------------- Torch components ---------------
-import copy
-import torch
-import torch.nn as nn
-
-# --------------- Model components ---------------
-from .yolov8_backbone import Yolov8Backbone
-from .yolov8_neck     import SPPF
-from .yolov8_pafpn    import Yolov8PaFPN
-from .yolov8_head     import Yolov8DetHead
-from .yolov8_pred     import Yolov8DetPredLayer
-
-
-# End-to-End YOLOv8
-class Yolov8E2E(nn.Module):
-    def __init__(self, cfg, is_val = False):
-        super(Yolov8E2E, self).__init__()
-        # ---------------------- Basic setting ----------------------
-        self.cfg = cfg
-        self.num_classes = cfg.num_classes
-        ## Post-process parameters
-        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
-        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
-        self.no_multi_labels  = False if is_val else True
-        
-        # ---------------------- Model Parameters ----------------------
-        ## Backbone
-        self.backbone = Yolov8Backbone(cfg)
-        self.pyramid_feat_dims = self.backbone.feat_dims[-3:]
-        ## Neck
-        self.neck     = SPPF(cfg, self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1])
-        self.pyramid_feat_dims[-1] = self.neck.out_dim
-        ## Neck: PaFPN
-        self.fpn      = Yolov8PaFPN(cfg, self.backbone.feat_dims)
-        ## Head (one-to-one)
-        self.head_o2o = Yolov8DetHead(cfg, self.fpn.out_dims)
-        ## Pred (one-to-one)
-        self.pred_o2o = Yolov8DetPredLayer(cfg, self.head_o2o.cls_head_dim, self.head_o2o.reg_head_dim)
-
-        ## Aux head (one-to-many)
-        self.head_o2m = copy.deepcopy(self.head_o2o)
-        ## Aux Pred (one-to-many)
-        self.pred_o2m = copy.deepcopy(self.pred_o2o)
-
-    def post_process(self, cls_preds, box_preds):
-        """
-        We process predictions at each scale hierarchically
-        Input:
-            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
-            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
-        Output:
-            bboxes: np.array -> [N, 4]
-            scores: np.array -> [N,]
-            labels: np.array -> [N,]
-        """
-        all_scores = []
-        all_labels = []
-        all_bboxes = []
-        
-        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
-            cls_pred_i = cls_pred_i[0]
-            box_pred_i = box_pred_i[0]
-            if self.no_multi_labels:
-                # [M,]
-                scores, labels = torch.max(cls_pred_i.sigmoid(), dim=1)
-
-                # Keep top k top scoring indices only.
-                num_topk = min(self.topk_candidates, box_pred_i.size(0))
-
-                # topk candidates
-                predicted_prob, topk_idxs = scores.sort(descending=True)
-                topk_scores = predicted_prob[:num_topk]
-                topk_idxs = topk_idxs[:num_topk]
-
-                # filter out the proposals with low confidence score
-                keep_idxs = topk_scores > self.conf_thresh
-                scores = topk_scores[keep_idxs]
-                topk_idxs = topk_idxs[keep_idxs]
-
-                labels = labels[topk_idxs]
-                bboxes = box_pred_i[topk_idxs]
-            else:
-                # [M, C] -> [MC,]
-                scores_i = cls_pred_i.sigmoid().flatten()
-
-                # Keep top k top scoring indices only.
-                num_topk = min(self.topk_candidates, box_pred_i.size(0))
-
-                # torch.sort is actually faster than .topk (at least on GPUs)
-                predicted_prob, topk_idxs = scores_i.sort(descending=True)
-                topk_scores = predicted_prob[:num_topk]
-                topk_idxs = topk_idxs[:num_topk]
-
-                # filter out the proposals with low confidence score
-                keep_idxs = topk_scores > self.conf_thresh
-                scores = topk_scores[keep_idxs]
-                topk_idxs = topk_idxs[keep_idxs]
-
-                anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
-                labels = topk_idxs % self.num_classes
-
-                bboxes = box_pred_i[anchor_idxs]
-
-            all_scores.append(scores)
-            all_labels.append(labels)
-            all_bboxes.append(bboxes)
-
-        scores = torch.cat(all_scores, dim=0)
-        labels = torch.cat(all_labels, dim=0)
-        bboxes = torch.cat(all_bboxes, dim=0)
-
-        # to cpu & numpy
-        scores = scores.cpu().numpy()
-        labels = labels.cpu().numpy()
-        bboxes = bboxes.cpu().numpy()
-
-        return bboxes, scores, labels
-    
-    def inference_o2o(self, x):
-        # ---------------- Backbone ----------------
-        pyramid_feats = self.backbone(x)
-        # ---------------- Neck: SPP ----------------
-        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-        # ---------------- Neck: PaFPN ----------------
-        pyramid_feats = self.fpn(pyramid_feats)
-
-        # ---------------- Heads ----------------
-        cls_feats, reg_feats = self.head_o2o(pyramid_feats)
-
-        # ---------------- Preds ----------------
-        outputs = self.pred_o2o(cls_feats, reg_feats)
-        outputs['image_size'] = [x.shape[2], x.shape[3]]
-
-        all_cls_preds = outputs['pred_cls']
-        all_box_preds = outputs['pred_box']
-
-        # post process (no NMS)
-        bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
-        outputs = {
-            "scores": scores,
-            "labels": labels,
-            "bboxes": bboxes
-        }
-        
-        return outputs 
-
-    def forward(self, x):
-        if not self.training:
-            return self.inference_o2o(x)
-        else:
-            # ---------------- Backbone ----------------
-            pyramid_feats = self.backbone(x)
-            # ---------------- Neck: SPP ----------------
-            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-            # ---------------- Neck: PaFPN ----------------
-            pyramid_feats = self.fpn(pyramid_feats)
-
-            # ---------------- Heads ----------------
-            o2m_cls_feats, o2m_reg_feats = self.head_o2m(pyramid_feats)
-
-            # ---------------- Preds ----------------
-            outputs_o2m = self.pred_o2m(o2m_cls_feats, o2m_reg_feats)
-            outputs_o2m['image_size'] = [x.shape[2], x.shape[3]]
-            
-            # ---------------- Heads (one-to-one) ----------------
-            o2o_cls_feats, o2o_reg_feats = self.head_o2o([feat.detach() for feat in pyramid_feats])
-
-            # ---------------- Preds (one-to-one) ----------------
-            outputs_o2o = self.pred_o2o(o2o_cls_feats, o2o_reg_feats)
-            outputs_o2o['image_size'] = [x.shape[2], x.shape[3]]
-
-            outputs = {
-                "outputs_o2m": outputs_o2m,
-                "outputs_o2o": outputs_o2o,
-            }
-            
-            return outputs 

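Note on the removed post_process above: in the multi-label branch, class scores are flattened to [M * C], the top-k flat indices are split back into an anchor index (flat_idx // num_classes) and a class label (flat_idx % num_classes), and low-confidence candidates are dropped. No NMS follows, since the one-to-one head is trained to emit a single prediction per object. A minimal sketch of that index arithmetic, using random tensors for illustration only:

import torch

num_classes, topk, conf_thresh = 20, 100, 0.3
cls_pred = torch.randn(400, num_classes)     # [M, C] class logits for one image
box_pred = torch.randn(400, 4)               # [M, 4] decoded boxes

scores = cls_pred.sigmoid().flatten()        # [M * C]
predicted_prob, topk_idxs = scores.sort(descending=True)
topk_scores = predicted_prob[:topk]
topk_idxs = topk_idxs[:topk]

keep = topk_scores > conf_thresh             # drop low-confidence candidates
scores, topk_idxs = topk_scores[keep], topk_idxs[keep]

anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor')
labels = topk_idxs % num_classes
bboxes = box_pred[anchor_idxs]               # one box per kept (anchor, class) pair
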
+ 0 - 179
yolo/models/yolov8_e2e/yolov8_head.py

@@ -1,179 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .yolov8_basic import BasicConv
-except ImportError:
-    from  yolov8_basic import BasicConv
-
-
-# -------------------- Detection Head --------------------
-## Single-level Detection Head
-class DetHead(nn.Module):
-    def __init__(self,
-                 in_dim       :int  = 256,
-                 cls_head_dim :int  = 256,
-                 reg_head_dim :int  = 256,
-                 num_cls_head :int  = 2,
-                 num_reg_head :int  = 2,
-                 act_type     :str  = "silu",
-                 norm_type    :str  = "BN",
-                 depthwise    :bool = False):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.in_dim = in_dim
-        self.num_cls_head = num_cls_head
-        self.num_reg_head = num_reg_head
-        self.act_type = act_type
-        self.norm_type = norm_type
-        self.depthwise = depthwise
-        
-        # --------- Network Parameters ----------
-        ## cls head
-        cls_feats = []
-        self.cls_head_dim = cls_head_dim
-        for i in range(num_cls_head):
-            if i == 0:
-                cls_feats.append(
-                    BasicConv(in_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
-            else:
-                cls_feats.append(
-                    BasicConv(self.cls_head_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
-        ## reg head
-        reg_feats = []
-        self.reg_head_dim = reg_head_dim
-        for i in range(num_reg_head):
-            if i == 0:
-                reg_feats.append(
-                    BasicConv(in_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
-            else:
-                reg_feats.append(
-                    BasicConv(self.reg_head_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
-        self.cls_feats = nn.Sequential(*cls_feats)
-        self.reg_feats = nn.Sequential(*reg_feats)
-
-        self.init_weights()
-        
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, x):
-        """
-            in_feats: (Tensor) [B, C, H, W]
-        """
-        cls_feats = self.cls_feats(x)
-        reg_feats = self.reg_feats(x)
-
-        return cls_feats, reg_feats
-    
-## Multi-level Detection Head
-class Yolov8DetHead(nn.Module):
-    def __init__(self, cfg, in_dims):
-        super().__init__()
-        ## ----------- Network Parameters -----------
-        self.multi_level_heads = nn.ModuleList(
-            [DetHead(in_dim       = in_dims[level],
-                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
-                     reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
-                     num_cls_head = cfg.num_cls_head,
-                     num_reg_head = cfg.num_reg_head,
-                     act_type     = cfg.head_act,
-                     norm_type    = cfg.head_norm,
-                     depthwise    = cfg.head_depthwise)
-                     for level in range(cfg.num_levels)
-                     ])
-        # --------- Basic Parameters ----------
-        self.in_dims = in_dims
-        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
-        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
-
-
-    def forward(self, feats):
-        """
-            feats: List[(Tensor)] [[B, C, H, W], ...]
-        """
-        cls_feats = []
-        reg_feats = []
-        for feat, head in zip(feats, self.multi_level_heads):
-            # ---------------- Pred ----------------
-            cls_feat, reg_feat = head(feat)
-
-            cls_feats.append(cls_feat)
-            reg_feats.append(reg_feat)
-
-        return cls_feats, reg_feats
-
-
-if __name__=='__main__':
-    import time
-    from thop import profile
-    # Model config
-    
-    # YOLOv8-Base config
-    class Yolov8BaseConfig(object):
-        def __init__(self) -> None:
-            # ---------------- Model config ----------------
-            self.width    = 0.50
-            self.depth    = 0.34
-            self.ratio    = 2.0
-            self.reg_max  = 16
-            self.out_stride = [8, 16, 32]
-            self.max_stride = 32
-            self.num_levels = 3
-            ## Head
-            self.head_act  = 'lrelu'
-            self.head_norm = 'BN'
-            self.head_depthwise = False
-            self.num_cls_head   = 2
-            self.num_reg_head   = 2
-
-    cfg = Yolov8BaseConfig()
-    cfg.num_classes = 20
-
-    # Build a head
-    fpn_dims = [128, 256, 512]
-    pyramid_feats = [torch.randn(1, fpn_dims[0], 80, 80),
-                     torch.randn(1, fpn_dims[1], 40, 40),
-                     torch.randn(1, fpn_dims[2], 20, 20)]
-    head = Yolov8DetHead(cfg, fpn_dims)
-
-
-    # Inference
-    t0 = time.time()
-    cls_feats, reg_feats = head(pyramid_feats)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print("====== Yolov8 Head output ======")
-    for level, (cls_f, reg_f) in enumerate(zip(cls_feats, reg_feats)):
-        print("- Level-{} : ".format(level), cls_f.shape, reg_f.shape)
-
-    flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
-    

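Note on the head widths chosen in the removed Yolov8DetHead above: with the values from its __main__ test (in_dims = [128, 256, 512], num_classes = 20, reg_max = 16), the formulas work out to the same cls_dim = 128 and reg_dim = 64 used by the prediction-layer test further below.

in_dims, num_classes, reg_max = [128, 256, 512], 20, 16

cls_head_dim = max(in_dims[0], min(num_classes, 128))  # max(128, 20)    = 128
reg_head_dim = max(in_dims[0] // 4, 16, 4 * reg_max)   # max(32, 16, 64) = 64
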
+ 0 - 85
yolo/models/yolov8_e2e/yolov8_neck.py

@@ -1,85 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .yolov8_basic import BasicConv
-except ImportError:
-    from  yolov8_basic import BasicConv
-    
-
-# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
-class SPPF(nn.Module):
-    """
-        This code is adapted from https://github.com/ultralytics/yolov5
-    """
-    def __init__(self, cfg, in_dim, out_dim):
-        super().__init__()
-        ## ----------- Basic Parameters -----------
-        inter_dim = round(in_dim * cfg.neck_expand_ratio)
-        self.out_dim = out_dim
-        ## ----------- Network Parameters -----------
-        self.cv1 = BasicConv(in_dim, inter_dim,
-                             kernel_size=1, padding=0, stride=1,
-                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
-        self.cv2 = BasicConv(inter_dim * 4, out_dim,
-                             kernel_size=1, padding=0, stride=1,
-                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
-        self.m = nn.MaxPool2d(kernel_size=cfg.spp_pooling_size,
-                              stride=1,
-                              padding=cfg.spp_pooling_size // 2)
-
-        # Initialize all layers
-        self.init_weights()
-
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, x):
-        x = self.cv1(x)
-        y1 = self.m(x)
-        y2 = self.m(y1)
-
-        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
-
-
-if __name__=='__main__':
-    import time
-    from thop import profile
-    # Model config
-    
-    # YOLOv8-Base config
-    class Yolov8BaseConfig(object):
-        def __init__(self) -> None:
-            # ---------------- Model config ----------------
-            self.out_stride = 32
-            self.max_stride = 32
-            ## Neck
-            self.neck_act       = 'lrelu'
-            self.neck_norm      = 'BN'
-            self.neck_depthwise = False
-            self.neck_expand_ratio = 0.5
-            self.spp_pooling_size  = 5
-
-    cfg = Yolov8BaseConfig()
-    # Build a head
-    in_dim  = 512
-    out_dim = 512
-    neck = SPPF(cfg, in_dim, out_dim)
-
-    # Inference
-    x = torch.randn(1, in_dim, 20, 20)
-    t0 = time.time()
-    output = neck(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print('Neck output: ', output.shape)
-
-    flops, params = profile(neck, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))

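Note on the removed SPPF above: with stride 1 and padding k // 2, re-applying the same k x k max-pool grows the effective kernel to 2k - 1 and then 3k - 2 (5, 9, 13 for the default pooling size of 5), so concatenating [x, y1, y2, m(y2)] covers four scales with a single pooling layer. An illustrative sketch of the first step of that equivalence:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 20, 20)
m5 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
m9 = nn.MaxPool2d(kernel_size=9, stride=1, padding=4)
assert torch.allclose(m5(m5(x)), m9(x))   # two stacked 5x5 pools == one 9x9 pool
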
+ 0 - 152
yolo/models/yolov8_e2e/yolov8_pafpn.py

@@ -1,152 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import List
-
-try:
-    from .yolov8_basic import BasicConv, ELANLayer
-except ImportError:
-    from  yolov8_basic import BasicConv, ELANLayer
-
-
-# YOLOv8's PaFPN
-class Yolov8PaFPN(nn.Module):
-    def __init__(self,
-                 cfg,
-                 in_dims :List = [256, 512, 1024],
-                 ) -> None:
-        super(Yolov8PaFPN, self).__init__()
-        print('==============================')
-        print('FPN: {}'.format("Yolo PaFPN"))
-        # --------------------------- Basic Parameters ---------------------------
-        self.in_dims = in_dims[::-1]
-        self.out_dims = [round(256*cfg.width), round(512*cfg.width), round(512*cfg.width*cfg.ratio)]
-
-        # ----------------------------- Yolov8's Top-down FPN -----------------------------
-        ## P5 -> P4
-        self.top_down_layer_1 = ELANLayer(in_dim     = self.in_dims[0] + self.in_dims[1],
-                                          out_dim    = round(512*cfg.width),
-                                          expansion  = 0.5,
-                                          num_blocks = round(3 * cfg.depth),
-                                          shortcut   = False,
-                                          act_type   = cfg.fpn_act,
-                                          norm_type  = cfg.fpn_norm,
-                                          depthwise  = cfg.fpn_depthwise,
-                                          )
-        ## P4 -> P3
-        self.top_down_layer_2 = ELANLayer(in_dim     = self.in_dims[2] + round(512*cfg.width),
-                                          out_dim    = round(256*cfg.width),
-                                          expansion  = 0.5,
-                                          num_blocks = round(3 * cfg.depth),
-                                          shortcut   = False,
-                                          act_type   = cfg.fpn_act,
-                                          norm_type  = cfg.fpn_norm,
-                                          depthwise  = cfg.fpn_depthwise,
-                                          )
-        # ----------------------------- Yolov8's Bottom-up PAN -----------------------------
-        ## P3 -> P4
-        self.dowmsample_layer_1 = BasicConv(round(256*cfg.width), round(256*cfg.width),
-                                            kernel_size=3, padding=1, stride=2,
-                                            act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
-        self.bottom_up_layer_1 = ELANLayer(in_dim     = round(256*cfg.width) + round(512*cfg.width),
-                                           out_dim    = round(512*cfg.width),
-                                           expansion  = 0.5,
-                                           num_blocks = round(3 * cfg.depth),
-                                           shortcut   = False,
-                                           act_type   = cfg.fpn_act,
-                                           norm_type  = cfg.fpn_norm,
-                                           depthwise  = cfg.fpn_depthwise,
-                                           )
-        ## P4 -> P5
-        self.dowmsample_layer_2 = BasicConv(round(512*cfg.width), round(512*cfg.width),
-                                            kernel_size=3, padding=1, stride=2,
-                                            act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
-        self.bottom_up_layer_2 = ELANLayer(in_dim     = round(512*cfg.width) + self.in_dims[0],
-                                           out_dim    = round(512*cfg.width*cfg.ratio),
-                                           expansion  = 0.5,
-                                           num_blocks = round(3 * cfg.depth),
-                                           shortcut   = False,
-                                           act_type   = cfg.fpn_act,
-                                           norm_type  = cfg.fpn_norm,
-                                           depthwise  = cfg.fpn_depthwise,
-                                           )
-
-        self.init_weights()
-        
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, features):
-        c3, c4, c5 = features
-
-        # ------------------ Top down FPN ------------------
-        ## P5 -> P4
-        p5_up = F.interpolate(c5, scale_factor=2.0)
-        p4 = self.top_down_layer_1(torch.cat([p5_up, c4], dim=1))
-
-        ## P4 -> P3
-        p4_up = F.interpolate(p4, scale_factor=2.0)
-        p3 = self.top_down_layer_2(torch.cat([p4_up, c3], dim=1))
-
-        # ------------------ Bottom up FPN ------------------
-        ## p3 -> P4
-        p3_ds = self.dowmsample_layer_1(p3)
-        p4 = self.bottom_up_layer_1(torch.cat([p3_ds, p4], dim=1))
-
-        ## P4 -> P5
-        p4_ds = self.dowmsample_layer_2(p4)
-        p5 = self.bottom_up_layer_2(torch.cat([p4_ds, c5], dim=1))
-
-        out_feats = [p3, p4, p5] # [P3, P4, P5]
-                
-        return out_feats
-    
-
-if __name__=='__main__':
-    import time
-    from thop import profile
-    # Model config
-    
-    # YOLOv8-Base config
-    class Yolov8BaseConfig(object):
-        def __init__(self) -> None:
-            # ---------------- Model config ----------------
-            self.width    = 0.50
-            self.depth    = 0.34
-            self.ratio    = 2.0
-            self.out_stride = [8, 16, 32]
-            self.max_stride = 32
-            self.num_levels = 3
-            ## FPN
-            self.fpn_act  = 'silu'
-            self.fpn_norm = 'BN'
-            self.fpn_depthwise = False
-            ## Head
-            self.head_dim = 256
-
-    cfg = Yolov8BaseConfig()
-    # Build a head
-    in_dims  = [128, 256, 512]
-    fpn = Yolov8PaFPN(cfg, in_dims)
-
-    # Inference
-    x = [torch.randn(1, in_dims[0], 80, 80),
-         torch.randn(1, in_dims[1], 40, 40),
-         torch.randn(1, in_dims[2], 20, 20)]
-    t0 = time.time()
-    output = fpn(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print('====== FPN output ====== ')
-    for level, feat in enumerate(output):
-        print("- Level-{} : ".format(level), feat.shape)
-
-    flops, params = profile(fpn, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))

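Note on the removed Yolov8PaFPN above: its output widths come from the scaling factors rather than from in_dims. With the __main__ config shown (width = 0.50, ratio = 2.0) this yields [128, 256, 512] for P3/P4/P5, matching the fpn_dims fed to the detection-head test earlier.

width, ratio = 0.50, 2.0
out_dims = [round(256 * width), round(512 * width), round(512 * width * ratio)]
assert out_dims == [128, 256, 512]   # P3, P4, P5 channels
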
+ 0 - 210
yolo/models/yolov8_e2e/yolov8_pred.py

@@ -1,210 +0,0 @@
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-# -------------------- Detection Pred Layer --------------------
-## Single-level pred layer
-class DetPredLayer(nn.Module):
-    def __init__(self,
-                 cls_dim     :int = 256,
-                 reg_dim     :int = 256,
-                 stride      :int = 32,
-                 reg_max     :int = 16,
-                 num_classes :int = 80,
-                 num_coords  :int = 4):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.stride = stride
-        self.cls_dim = cls_dim
-        self.reg_dim = reg_dim
-        self.reg_max = reg_max
-        self.num_classes = num_classes
-        self.num_coords = num_coords
-
-        # --------- Network Parameters ----------
-        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
-        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)                
-
-        self.init_bias()
-        
-    def init_bias(self):
-        # cls pred bias
-        b = self.cls_pred.bias.view(1, -1)
-        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
-        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # reg pred bias
-        b = self.reg_pred.bias.view(-1, )
-        b.data.fill_(1.0)
-        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        w = self.reg_pred.weight
-        w.data.fill_(0.)
-        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
-    def generate_anchors(self, fmp_size):
-        """
-            fmp_size: (List) [H, W]
-        """
-        # generate grid cells
-        fmp_h, fmp_w = fmp_size
-        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-        # [H, W, 2] -> [HW, 2]
-        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
-        anchors += 0.5  # add center offset
-        anchors *= self.stride
-
-        return anchors
-        
-    def forward(self, cls_feat, reg_feat):
-        # pred
-        cls_pred = self.cls_pred(cls_feat)
-        reg_pred = self.reg_pred(reg_feat)
-
-        # generate anchor boxes: [M, 4]
-        B, _, H, W = cls_pred.size()
-        fmp_size = [H, W]
-        anchors = self.generate_anchors(fmp_size)
-        anchors = anchors.to(cls_pred.device)
-        # stride tensor: [M, 1]
-        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
-        
-        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
-        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
-        
-        # output dict
-        outputs = {"pred_cls": cls_pred,            # (Tensor) [B, M, C]
-                   "pred_reg": reg_pred,            # (Tensor) [B, M, 4*(reg_max)]
-                   "anchors": anchors,              # (Tensor) [M, 2]
-                   "strides": self.stride,          # (Int) stride of this level
-                   "stride_tensor": stride_tensor   # (Tensor) [M, 1]
-                   }
-
-        return outputs
-
-## Multi-level pred layer
-class Yolov8DetPredLayer(nn.Module):
-    def __init__(self,
-                 cfg,
-                 cls_dim,
-                 reg_dim,
-                 ):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.cfg = cfg
-        self.cls_dim = cls_dim
-        self.reg_dim = reg_dim
-
-        # ----------- Network Parameters -----------
-        ## pred layers
-        self.multi_level_preds = nn.ModuleList(
-            [DetPredLayer(cls_dim     = cls_dim,
-                          reg_dim     = reg_dim,
-                          stride      = cfg.out_stride[level],
-                          reg_max     = cfg.reg_max,
-                          num_classes = cfg.num_classes,
-                          num_coords  = 4 * cfg.reg_max)
-                          for level in range(cfg.num_levels)
-                          ])
-        ## proj conv
-        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
-        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
-        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, cfg.reg_max, 1, 1]), requires_grad=False)
-
-    def forward(self, cls_feats, reg_feats):
-        all_anchors = []
-        all_strides = []
-        all_cls_preds = []
-        all_reg_preds = []
-        all_box_preds = []
-        for level in range(self.cfg.num_levels):
-            # -------------- Single-level prediction --------------
-            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
-
-            # -------------- Decode bbox --------------
-            B, M = outputs["pred_reg"].shape[:2]
-            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
-            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.cfg.reg_max])
-            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
-            # [B, reg_max, 4, M] -> [B, 1, 4, M]
-            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
-            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
-            ## tlbr -> xyxy
-            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.cfg.out_stride[level]
-            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.cfg.out_stride[level]
-            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
-
-            # collect results
-            all_cls_preds.append(outputs["pred_cls"])
-            all_reg_preds.append(outputs["pred_reg"])
-            all_box_preds.append(box_pred)
-            all_anchors.append(outputs["anchors"])
-            all_strides.append(outputs["stride_tensor"])
-        
-        # output dict
-        outputs = {"pred_cls":      all_cls_preds,         # List(Tensor) [B, M, C]
-                   "pred_reg":      all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
-                   "pred_box":      all_box_preds,         # List(Tensor) [B, M, 4]
-                   "anchors":       all_anchors,           # List(Tensor) [M, 2]
-                   "stride_tensor": all_strides,           # List(Tensor) [M, 1]
-                   "strides":       self.cfg.out_stride,   # List(Int) = [8, 16, 32]
-                   }
-
-        return outputs
-
-
-if __name__=='__main__':
-    import time
-    from thop import profile
-    # Model config
-    
-    # YOLOv8-Base config
-    class Yolov8BaseConfig(object):
-        def __init__(self) -> None:
-            # ---------------- Model config ----------------
-            self.width    = 1.0
-            self.depth    = 1.0
-            self.ratio    = 1.0
-            self.reg_max  = 16
-            self.out_stride = [8, 16, 32]
-            self.max_stride = 32
-            self.num_levels = 3
-            ## Head
-
-    cfg = Yolov8BaseConfig()
-    cfg.num_classes = 20
-    cls_dim = 128
-    reg_dim = 64
-    # Build a pred layer
-    pred = Yolov8DetPredLayer(cfg, cls_dim, reg_dim)
-
-    # Inference
-    cls_feats = [torch.randn(1, cls_dim, 80, 80),
-                 torch.randn(1, cls_dim, 40, 40),
-                 torch.randn(1, cls_dim, 20, 20),]
-    reg_feats = [torch.randn(1, reg_dim, 80, 80),
-                 torch.randn(1, reg_dim, 40, 40),
-                 torch.randn(1, reg_dim, 20, 20),]
-    t0 = time.time()
-    output = pred(cls_feats, reg_feats)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print('====== Pred output ======= ')
-    pred_cls = output["pred_cls"]
-    pred_reg = output["pred_reg"]
-    pred_box = output["pred_box"]
-    anchors  = output["anchors"]
-    
-    for level in range(cfg.num_levels):
-        print("- Level-{} : classification   -> {}".format(level, pred_cls[level].shape))
-        print("- Level-{} : delta regression -> {}".format(level, pred_reg[level].shape))
-        print("- Level-{} : bbox regression  -> {}".format(level, pred_box[level].shape))
-        print("- Level-{} : anchor boxes     -> {}".format(level, anchors[level].shape))
-
-    flops, params = profile(pred, inputs=(cls_feats, reg_feats, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
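
Note on the removed Yolov8DetPredLayer above: the frozen proj_conv, whose weights are arange(reg_max), turns each side's softmax over reg_max bins into an expected bin index, i.e. a distance from the anchor in stride units. The sketch below reproduces that decoding without the conv, purely for illustration.

import torch

reg_max, stride = 16, 8
anchors = torch.tensor([[100.0, 60.0]])                  # [M, 2] anchor centers, in pixels
reg_pred = torch.randn(1, 1, 4 * reg_max)                # [B, M, 4 * reg_max] raw logits

B, M = reg_pred.shape[:2]
dist = reg_pred.view(B, M, 4, reg_max).softmax(dim=-1)   # per-side bin distribution
bins = torch.arange(reg_max, dtype=torch.float32)
delta = (dist * bins).sum(dim=-1)                        # [B, M, 4] expected distances

x1y1 = anchors[None] - delta[..., :2] * stride           # left / top
x2y2 = anchors[None] + delta[..., 2:] * stride           # right / bottom
box_pred = torch.cat([x1y1, x2y2], dim=-1)               # [B, M, 4] boxes in xyxy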