
add iclab for imagenet pretraining

yjh0410 1 year ago
parent
commit
ce28ce6da0

+ 11 - 0
iclab/.gitignore

@@ -0,0 +1,11 @@
+*.pt
+*.pth
+*.pkl
+*.onnx
+*.pyc
+*.zip
+weights
+__pycache__
+.vscode
+data/cifar_data/
+data/mnist_data/

+ 61 - 0
iclab/README.md

@@ -0,0 +1,61 @@
+# General Image Classification Laboratory
+
+
+## Train a CNN
+We provide a bash script, `train.sh`, to train the models. You can modify the hyperparameters in the script to suit your needs.
+
+For example, to train the `ELANDarkNet-S` designed in this repo on 8 GPUs, run:
+
+```Shell
+bash train.sh elandarknet_s imagenet_1k path/to/imagenet_1k 8 1699 None
+```
+
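+The positional arguments are, in order: model name, dataset name, dataset root, number of GPUs, master port for DDP, and a checkpoint to resume from (`None` trains from scratch). This reading is inferred from the example above; see `train.sh` for the authoritative order.
+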
+## Evaluate a CNN
+- Evaluate the `top1 & top5` accuracy of `ELANDarkNet-S` on the ImageNet-1K dataset:
+```Shell
+python main.py --cuda --dataset imagenet_1k --root path/to/imagenet_1k -m elandarknet_s --batch_size 256 --img_size 224 --eval --resume path/to/elandarknet_s.pth
+```
+
+
+## Experimental results
+Tips:
+- **Weak augmentation:** includes `random hflip` and `random crop resize`.
+- **Strong augmentation:** includes `mixup`, `cutmix`, `rand aug`, `random erase`, and so on. However, we do not use strong augmentation for the results below.
+- We use `AdamW` with `weight decay = 0.05` and `base lr = 4e-3` (for a batch size of 4096) as the optimizer, and `CosineAnnealingLR` with `min lr = 1e-6` as the lr scheduler; see the sketch below.
+
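+For reference, here is a minimal sketch of this recipe in plain PyTorch (the variable names are illustrative and not the exact code used in this repo):
+
+```Python
+import torch
+import torch.nn as nn
+
+model = nn.Linear(8, 8)  # stand-in for the real backbone
+
+base_lr, base_batch_size = 4e-3, 4096
+batch_size = 4096
+lr = base_lr * batch_size / base_batch_size  # linear scaling rule
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
+```
+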
+### ImageNet-1K
+* Modified DarkNet (YOLOv3's DarkNet with width and depth scaling factors)
+
+|    Model      | Augment | Batch | Epoch | size | acc@1 | GFLOPs | Params |  Weight |
+|---------------|---------|-------|-------|------|-------|--------|--------|---------|
+| DarkNet-S     |   weak  |  4096 |  100  | 224  |  68.5 |  1.6   |  4.6 M | [ckpt](https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/darknet_s_in1k_68.5.pth) |
+| DarkNet-M     |   weak  |  4096 |  100  | 224  |       |        |        |  |
+| DarkNet-L     |   weak  |  4096 |  100  | 224  |       |        |        |  |
+| DarkNet-X     |   weak  |  4096 |  100  | 224  |       |        |        |  |
+
+* Modified CSPDarkNet (YOLOv5's DarkNet with width and depth scaling factors)
+
+|    Model      | Augment | Batch | Epoch | size | acc@1 | GFLOPs | Params |  Weight |
+|---------------|---------|-------|-------|------|-------|--------|--------|---------|
+| CSPDarkNet-S  |   weak  |  4096 |  100  | 224  |  70.2 |  1.3   | 4.0 M  | [ckpt](https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/cspdarknet_s_in1k_70.2.pth) |
+| CSPDarkNet-M  |   weak  |  4096 |  100  | 224  |       |        |        |  |
+| CSPDarkNet-L  |   weak  |  4096 |  100  | 224  |       |        |        |  |
+| CSPDarkNet-X  |   weak  |  4096 |  100  | 224  |       |        |        |  |
+
+* ELANDarkNet (YOLOv8's backbone)
+
+|         Model          | Augment | Batch | Epoch | size | acc@1 | GFLOPs | Params  |  Weight |
+|------------------------|---------|-------|-------|------|-------|--------|---------|---------|
+| ELANDarkNet-N      |   weak  |  4096 |  100  | 224  |  62.1 |  0.38  | 1.36 M  | [ckpt](https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elandarknet_n_in1k_62.1.pth) |
+| ELANDarkNet-S      |   weak  |  4096 |  100  | 224  |  71.3 |  1.48  | 4.94 M  | [ckpt](https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elandarknet_s_in1k_71.3.pth) |
+| ELANDarkNet-M      |   weak  |  4096 |  100  | 224  |       |  4.67  | 11.60 M |  |
+| ELANDarkNet-L      |   weak  |  4096 |  100  | 224  |       |  10.47 | 19.66 M |  |
+| ELANDarkNet-X      |   weak  |  4096 |  100  | 224  |       |  20.56 | 37.86 M |  |
+
+
+* GELAN (Proposed by YOLOv9)
+
+|     Model     | Augment | Batch | Epoch | size | acc@1 | GFLOPs | Params  |  Weight |
+|---------------|---------|-------|-------|------|-------|--------|---------|---------|
+| GELAN-S       |   weak  |  4096 |  100  | 224  | 68.4  |  0.9   | 1.9 M   | [ckpt](https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/gelan_s_in1k_68.4.pth) |
+| GELAN-C       |   weak  |  4096 |  100  | 224  |       |  5.2   | 8.8 M   |  |

+ 36 - 0
iclab/data/__init__.py

@@ -0,0 +1,36 @@
+import torch.utils.data as data
+
+from .cifar import CifarDataset
+from .mnist import MnistDataset
+from .imagenet import ImageNet1KDataset
+from .custom import CustomDataset
+
+
+def build_dataset(args, transform=None, is_train=False):
+    if args.dataset == 'cifar10':
+        args.num_classes = 10
+        args.img_dim = 3
+        return CifarDataset(is_train, transform)
+    elif args.dataset == 'mnist':
+        args.num_classes = 10
+        args.img_dim = 1
+        return MnistDataset(is_train, transform)
+    elif args.dataset == 'imagenet_1k':
+        args.num_classes = 1000
+        args.img_dim = 3
+        return ImageNet1KDataset(args, is_train, transform)
+    elif args.dataset == 'custom':
+        assert args.num_classes is not None and isinstance(args.num_classes, int)
+        args.img_dim = 3
+        return CustomDataset(args, is_train, transform)
+    else:
+        raise NotImplementedError("Unknown dataset: {}".format(args.dataset))
+
+
+def build_dataloader(args, dataset, is_train=False):
+    if is_train:
+        sampler = data.distributed.DistributedSampler(dataset) if args.distributed else data.RandomSampler(dataset)
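+        # Each rank draws batch_size // world_size samples, keeping the global batch at args.batch_size.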
+        batch_sampler_train = data.BatchSampler(sampler, args.batch_size // args.world_size, drop_last=True)
+        dataloader = data.DataLoader(dataset, batch_sampler=batch_sampler_train, num_workers=args.num_workers, pin_memory=True)
+    else:
+        dataloader = data.DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
+
+    return dataloader
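+
+# A minimal usage sketch (hypothetical; single-GPU, non-distributed args):
+#   from types import SimpleNamespace
+#   args = SimpleNamespace(dataset='cifar10', batch_size=256, num_workers=4,
+#                          distributed=False, world_size=1)
+#   train_set    = build_dataset(args, is_train=True)
+#   train_loader = build_dataloader(args, train_set, is_train=True)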

+ 78 - 0
iclab/data/cifar.py

@@ -0,0 +1,78 @@
+import os
+import numpy as np
+import torch.utils.data as data
+import torchvision.transforms as T
+from torchvision.datasets import CIFAR10
+
+
+class CifarDataset(data.Dataset):
+    def __init__(self, is_train=False, transform=None):
+        super().__init__()
+        # ----------------- basic parameters -----------------
+        self.is_train   = is_train
+        self.pixel_mean = [0.0]
+        self.pixel_std  = [1.0]
+        self.image_set  = 'train' if is_train else 'val'
+        # ----------------- dataset & transforms -----------------
+        self.transform = transform if transform is not None else self.build_transform()
+        path = os.path.dirname(os.path.abspath(__file__))
+        if is_train:
+            self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=True, download=True, transform=self.transform)
+        else:
+            self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=False, download=True, transform=self.transform)
+
+    def __len__(self):
+        return len(self.dataset)
+    
+    def __getitem__(self, index):
+        image, target = self.dataset[index]
+            
+        return image, target
+    
+    def pull_image(self, index):
+        # load data
+        image, target = self.dataset[index]
+
+        # ------- Denormalize image -------
+        ## [C, H, W] -> [H, W, C], torch.Tensor -> numpy.ndarray
+        image = image.permute(1, 2, 0).numpy()
+        ## Denormalize: I = (I_n * std + mean) * 255
+        image = (image * self.pixel_std + self.pixel_mean) * 255.
+
+        image = image.astype(np.uint8)
+        image = image.copy()
+
+        return image, target
+
+    def build_transform(self):
+        if self.is_train:
+            transforms = T.Compose([T.ToTensor(), T.RandomCrop(size=32, padding=8)])
+        else:
+            transforms = T.Compose([T.ToTensor()])
+
+        return transforms
+
+if __name__ == "__main__":
+    import cv2
+    import argparse
+    
+    parser = argparse.ArgumentParser(description='Cifar-Dataset')
+
+    # opt
+    parser.add_argument('--is_train', action="store_true", default=False,
+                        help='train or not.')
+    
+    args = parser.parse_args()
+
+    # dataset
+    dataset = CifarDataset(is_train=args.is_train)  
+    print('Dataset size: ', len(dataset))
+
+    for i in range(1000):
+        image, target = dataset.pull_image(i)
+        # to BGR
+        image = image[..., (2, 1, 0)]
+
+        cv2.imshow('image', image)
+        cv2.waitKey(0)
+

+ 109 - 0
iclab/data/custom.py

@@ -0,0 +1,109 @@
+import os
+import PIL
+import numpy as np
+from timm.data import create_transform
+import torch.utils.data as data
+import torchvision.transforms as T
+from torchvision.datasets import ImageFolder
+
+
+class CustomDataset(data.Dataset):
+    def __init__(self, args, is_train=False, transform=None):
+        super().__init__()
+        # ----------------- basic parameters -----------------
+        self.args = args
+        self.is_train   = is_train
+        self.pixel_mean = [0.485, 0.456, 0.406]
+        self.pixel_std  = [0.229, 0.224, 0.225]
+        print("Pixel mean: {}".format(self.pixel_mean))
+        print("Pixel std:  {}".format(self.pixel_std))
+        self.image_set = 'train' if is_train else 'val'
+        self.data_path = os.path.join(args.root, self.image_set)
+        # ----------------- dataset & transforms -----------------
+        self.transform = transform if transform is not None else self.build_transform(args)
+        self.dataset = ImageFolder(root=self.data_path, transform=self.transform)
+
+    def __len__(self):
+        return len(self.dataset)
+    
+    def __getitem__(self, index):
+        image, target = self.dataset[index]
+
+        return image, target
+    
+    def pull_image(self, index):
+        # load data
+        image, target = self.dataset[index]
+
+        # denormalize image
+        image = image.permute(1, 2, 0).numpy()
+        image = (image * self.pixel_std + self.pixel_mean) * 255.
+        image = image.astype(np.uint8)
+        image = image.copy()
+
+        return image, target
+
+    def build_transform(self, args):
+        if self.is_train:
+            transforms = create_transform(input_size    = args.img_size,
+                                          is_training   = True,
+                                          color_jitter  = args.color_jitter,
+                                          auto_augment  = args.aa,
+                                          interpolation = 'bicubic',
+                                          re_prob       = args.reprob,
+                                          re_mode       = args.remode,
+                                          re_count      = args.recount,
+                                          mean          = self.pixel_mean,
+                                          std           = self.pixel_std,
+                                          )
+        else:
+            t = []
+            if args.img_size <= 224:
+                crop_pct = 224 / 256
+            else:
+                crop_pct = 1.0
+            size = int(args.img_size / crop_pct)
+            t.append(
+                T.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
+            )
+            t.append(T.CenterCrop(args.img_size))
+            t.append(T.ToTensor())
+            t.append(T.Normalize(self.pixel_mean, self.pixel_std))
+            transforms = T.Compose(t)
+
+        return transforms
+
+
+if __name__ == "__main__":
+    import cv2
+    import torch
+    import argparse
+    
+    parser = argparse.ArgumentParser(description='Custom-Dataset')
+
+    # opt
+    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/classification/dataset/Animals/',
+                        help='data root')
+    parser.add_argument('--img_size', default=224, type=int,
+                        help='input image size.')
+    args = parser.parse_args()
+
+    # Transforms
+    train_transform = T.Compose([
+            T.RandomResizedCrop(args.img_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic
+            T.RandomHorizontalFlip(),
+            T.ToTensor(),
+            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+  
+    # Dataset
+    dataset = CustomDataset(args, is_train=True, transform=train_transform)  
+    print('Dataset size: ', len(dataset))
+
+    for i in range(1000):
+        image, target = dataset.pull_image(i)
+        # to BGR
+        image = image[..., (2, 1, 0)]
+
+        cv2.imshow('image', image)
+        cv2.waitKey(0)

+ 109 - 0
iclab/data/imagenet.py

@@ -0,0 +1,109 @@
+import os
+import PIL
+import numpy as np
+from timm.data import create_transform
+import torch.utils.data as data
+import torchvision.transforms as T
+from torchvision.datasets import ImageFolder
+
+
+class ImageNet1KDataset(data.Dataset):
+    def __init__(self, args, is_train=False, transform=None):
+        super().__init__()
+        # ----------------- basic parameters -----------------
+        self.args = args
+        self.is_train   = is_train
+        self.pixel_mean = [0.485, 0.456, 0.406]
+        self.pixel_std  = [0.229, 0.224, 0.225]
+        print("Pixel mean: {}".format(self.pixel_mean))
+        print("Pixel std:  {}".format(self.pixel_std))
+        self.image_set = 'train' if is_train else 'val'
+        self.data_path = os.path.join(args.root, self.image_set)
+        # ----------------- dataset & transforms -----------------
+        self.transform = transform if transform is not None else self.build_transform(args)
+        self.dataset = ImageFolder(root=self.data_path, transform=self.transform)
+
+    def __len__(self):
+        return len(self.dataset)
+    
+    def __getitem__(self, index):
+        image, target = self.dataset[index]
+
+        return image, target
+    
+    def pull_image(self, index):
+        # load data
+        image, target = self.dataset[index]
+
+        # denormalize image
+        image = image.permute(1, 2, 0).numpy()
+        image = (image * self.pixel_std + self.pixel_mean) * 255.
+        image = image.astype(np.uint8)
+        image = image.copy()
+
+        return image, target
+
+    def build_transform(self, args):
+        if self.is_train:
+            transforms = create_transform(input_size    = args.img_size,
+                                          is_training   = True,
+                                          color_jitter  = args.color_jitter,
+                                          auto_augment  = args.aa,
+                                          interpolation = 'bicubic',
+                                          re_prob       = args.reprob,
+                                          re_mode       = args.remode,
+                                          re_count      = args.recount,
+                                          mean          = self.pixel_mean,
+                                          std           = self.pixel_std,
+                                          )
+        else:
+            t = []
+            if args.img_size <= 224:
+                crop_pct = 224 / 256
+            else:
+                crop_pct = 1.0
+            size = int(args.img_size / crop_pct)
+            t.append(
+                T.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
+            )
+            t.append(T.CenterCrop(args.img_size))
+            t.append(T.ToTensor())
+            t.append(T.Normalize(self.pixel_mean, self.pixel_std))
+            transforms = T.Compose(t)
+
+        return transforms
+
+
+if __name__ == "__main__":
+    import cv2
+    import torch
+    import argparse
+    
+    parser = argparse.ArgumentParser(description='ImageNet-Dataset')
+
+    # opt
+    parser.add_argument('--root', default='/mnt/share/ssd2/dataset/imagenet/',
+                        help='data root')
+    parser.add_argument('--img_size', default=224, type=int,
+                        help='input image size.')
+    args = parser.parse_args()
+
+    # Transforms
+    train_transform = T.Compose([
+            T.RandomResizedCrop(args.img_size, scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic
+            T.RandomHorizontalFlip(),
+            T.ToTensor(),
+            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+  
+    # Dataset
+    dataset = ImageNet1KDataset(args, is_train=True, transform=train_transform)  # pass the transform; this minimal args lacks the fields build_transform() expects
+    print('Dataset size: ', len(dataset))
+
+    for i in range(1000):
+        image, target = dataset.pull_image(i)
+        # to BGR
+        image = image[..., (2, 1, 0)]
+
+        cv2.imshow('image', image)
+        cv2.waitKey(0)

+ 60 - 0
iclab/data/mnist.py

@@ -0,0 +1,60 @@
+import os
+import torch.utils.data as data
+import torchvision.transforms as T
+from torchvision.datasets import MNIST
+
+
+class MnistDataset(data.Dataset):
+    def __init__(self, is_train=False, transform=None):
+        super().__init__()
+        # ----------------- basic parameters -----------------
+        self.is_train   = is_train
+        self.pixel_mean = [0.]
+        self.pixel_std  = [1.]
+        self.image_set  = 'train' if is_train else 'val'
+        # ----------------- dataset & transforms -----------------
+        self.transform = transform if transform is not None else self.build_transform()
+        path = os.path.dirname(os.path.abspath(__file__))
+        if is_train:
+            self.dataset = MNIST(os.path.join(path, 'mnist_data/'), train=True, download=True, transform=self.transform)
+        else:
+            self.dataset = MNIST(os.path.join(path, 'mnist_data/'), train=False, download=True, transform=self.transform)
+
+    def __len__(self):
+        return len(self.dataset)
+    
+    def __getitem__(self, index):
+        image, target = self.dataset[index]
+            
+        return image, target
+    
+    def pull_image(self, index):
+        # load data
+        image, target = self.dataset[index]
+
+        # denormalize image
+        image = image.permute(1, 2, 0).numpy()
+        image = image.copy()
+
+        return image, target
+
+    def build_transform(self):
+        if self.is_train:
+            transforms = T.Compose([T.ToTensor(),])
+        else:
+            transforms = T.Compose([T.ToTensor(),])
+
+        return transforms
+
+if __name__ == "__main__":
+    import cv2
+
+    # dataset
+    dataset = MnistDataset(is_train=True)  
+    print('Dataset size: ', len(dataset))
+
+    for i in range(1000):
+        image, target = dataset.pull_image(i)
+
+        cv2.imshow('image', image)
+        cv2.waitKey(0)

+ 131 - 0
iclab/engine.py

@@ -0,0 +1,131 @@
+import sys
+import math
+import numpy as np
+
+import torch
+
+from utils.misc import MetricLogger, SmoothedValue
+from utils.misc import print_rank_0, all_reduce_mean, accuracy
+
+
+def train_one_epoch(args,
+                    device,
+                    model,
+                    model_ema,
+                    data_loader,
+                    optimizer,
+                    epoch,
+                    lr_scheduler_warmup,
+                    loss_scaler,
+                    criterion,
+                    local_rank=0,
+                    tblogger=None,
+                    mixup_fn=None):
+    model.train(True)
+    metric_logger = MetricLogger(delimiter="  ")
+    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    header = 'Epoch: [{} / {}]'.format(epoch, args.max_epoch)
+    print_freq = 20
+    epoch_size = len(data_loader)
+
+    optimizer.zero_grad()
+
+    # train one epoch
+    for iter_i, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
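+        # ni: global iteration index; nw: total number of warmup iterations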
+        ni = iter_i + epoch * epoch_size
+        nw = args.wp_epoch * epoch_size
+        # Warmup
+        if nw > 0 and ni < nw:
+            lr_scheduler_warmup(ni, optimizer)
+        elif ni == nw:
+            print("Warmup stage is over.")
+            lr_scheduler_warmup.set_lr(optimizer, args.base_lr)
+
+        # To device
+        images = images.to(device, non_blocking=True)
+        targets = targets.to(device, non_blocking=True)
+
+        # Mixup
+        if mixup_fn is not None:
+            images, targets = mixup_fn(images, targets)
+
+        # Inference
+        with torch.cuda.amp.autocast():
+            output = model(images)
+            loss = criterion(output, targets)
+
+        # Check loss
+        loss_value = loss.item()
+        if not math.isfinite(loss_value):
+            print("Loss is {}, stopping training".format(loss_value))
+            sys.exit(1)
+
+        # Backward & Optimize
+        loss /= args.grad_accumulate
+        loss_scaler(loss, optimizer, clip_grad=args.max_grad_norm, 
+                    parameters=model.parameters(), create_graph=False,
+                    update_grad=(iter_i + 1) % args.grad_accumulate == 0)
+        if (iter_i + 1) % args.grad_accumulate == 0:
+            optimizer.zero_grad()
+            if model_ema is not None:
+                model_ema.update(model)
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+
+        # Logs
+        lr = optimizer.param_groups[0]["lr"]
+        metric_logger.update(loss=loss_value)
+        metric_logger.update(lr=lr)
+
+        loss_value_reduce = all_reduce_mean(loss_value)
+        if tblogger is not None and (iter_i + 1) % args.grad_accumulate == 0:
+            """ We use epoch_1000x as the x-axis in tensorboard.
+            This calibrates different curves when batch size changes.
+            """
+            epoch_1000x = int((iter_i / len(data_loader) + epoch) * 1000)
+            tblogger.add_scalar('loss', loss_value_reduce, epoch_1000x)
+            tblogger.add_scalar('lr', lr, epoch_1000x)
+
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print_rank_0("Averaged stats: {}".format(metric_logger), local_rank)
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate(data_loader, model, device, local_rank):
+    criterion = torch.nn.CrossEntropyLoss()
+
+    metric_logger = MetricLogger(delimiter="  ")
+    header = 'Test:'
+
+    # switch to evaluation mode
+    model.eval()
+
+    for batch in metric_logger.log_every(data_loader, 10, header):
+        images = batch[0]
+        target = batch[-1]
+        images = images.to(device, non_blocking=True)
+        target = target.to(device, non_blocking=True)
+
+        # compute output
+        with torch.cuda.amp.autocast():
+            output = model(images)
+            loss = criterion(output, target)
+
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        batch_size = images.shape[0]
+        metric_logger.update(loss=loss.item())
+        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+
+    # gather the stats from all processes
+    metric_logger.synchronize_between_processes()
+    print_rank_0('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+                 .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss),
+                 local_rank)
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

+ 20 - 0
iclab/models/__init__.py

@@ -0,0 +1,20 @@
+from .elandarknet.build import build_elandarknet
+from .cspdarknet.build  import build_cspdarknet
+from .darknet.build     import build_darknet
+from .gelan.build       import build_gelan
+
+
+def build_model(args):
+    # --------------------------- DarkNet series ---------------------------
+    if   'elandarknet' in args.model:
+        model = build_elandarknet(args)
+    elif 'cspdarknet' in args.model:
+        model = build_cspdarknet(args)
+    elif 'darknet' in args.model:
+        model = build_darknet(args)
+    elif 'gelan' in args.model:
+        model = build_gelan(args)
+    else:
+        raise NotImplementedError("Unknown model: {}".format(args.model))
+
+    return model

+ 18 - 0
iclab/models/cspdarknet/build.py

@@ -0,0 +1,18 @@
+from .cspdarknet import cspdarknet_n, cspdarknet_s, cspdarknet_m, cspdarknet_l, cspdarknet_x
+
+def build_cspdarknet(args):
+    # build a cspdarknet variant
+    if   args.model == 'cspdarknet_n':
+        model = cspdarknet_n(args.img_dim, args.num_classes)
+    elif args.model == 'cspdarknet_s':
+        model = cspdarknet_s(args.img_dim, args.num_classes)
+    elif args.model == 'cspdarknet_m':
+        model = cspdarknet_m(args.img_dim, args.num_classes)
+    elif args.model == 'cspdarknet_l':
+        model = cspdarknet_l(args.img_dim, args.num_classes)
+    elif args.model == 'cspdarknet_x':
+        model = cspdarknet_x(args.img_dim, args.num_classes)
+    else:
+        raise NotImplementedError("Unknown cspdarknet: {}".format(args.model))
+    
+    return model

+ 170 - 0
iclab/models/cspdarknet/cspdarknet.py

@@ -0,0 +1,170 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .modules import BasicConv, CSPBlock
+except:
+    from  modules import BasicConv, CSPBlock
+
+
+# ---------------------------- CSPDarkNet ----------------------------
+# CSPDarkNet
+class CSPDarkNet(nn.Module):
+    def __init__(self, img_dim=3, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False, num_classes=1000):
+        super(CSPDarkNet, self).__init__()
+        # ---------------- Basic parameters ----------------
+        self.width_factor = width
+        self.depth_factor = depth
+        self.feat_dims = [round(64 * width),
+                          round(128 * width),
+                          round(256 * width),
+                          round(512 * width),
+                          round(1024 * width)
+                          ]
+
+        # ---------------- Model parameters ----------------
+        ## P1/2
+        self.layer_1 = BasicConv(img_dim, self.feat_dims[0],
+                                 kernel_size=6, padding=2, stride=2,
+                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        
+        ## P2/4
+        self.layer_2 = nn.Sequential(
+            BasicConv(self.feat_dims[0], self.feat_dims[1],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            CSPBlock(self.feat_dims[1],
+                     self.feat_dims[1],
+                     num_blocks   = round(3*depth),
+                     expand_ratio = 0.5,
+                     shortcut     = True,
+                     act_type     = act_type,
+                     norm_type    = norm_type,
+                     depthwise    = depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            BasicConv(self.feat_dims[1], self.feat_dims[2],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            CSPBlock(self.feat_dims[2],
+                     self.feat_dims[2],
+                     num_blocks   = round(9*depth),
+                     expand_ratio = 0.5,
+                     shortcut     = True,
+                     act_type     = act_type,
+                     norm_type    = norm_type,
+                     depthwise    = depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            BasicConv(self.feat_dims[2], self.feat_dims[3],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            CSPBlock(self.feat_dims[3],
+                     self.feat_dims[3],
+                     num_blocks   = round(9*depth),
+                     expand_ratio = 0.5,
+                     shortcut     = True,
+                     act_type     = act_type,
+                     norm_type    = norm_type,
+                     depthwise    = depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            BasicConv(self.feat_dims[3], self.feat_dims[4],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            CSPBlock(self.feat_dims[4],
+                     self.feat_dims[4],
+                     num_blocks   = round(3*depth),
+                     expand_ratio = 0.5,
+                     shortcut     = True,
+                     act_type     = act_type,
+                     norm_type    = norm_type,
+                     depthwise    = depthwise)
+        )
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(self.feat_dims[4], num_classes)
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        c5 = self.avgpool(c5)
+        c5 = torch.flatten(c5, 1)
+        c5 = self.fc(c5)
+
+        return c5
+
+
+# ------------------------ Model Functions ------------------------
+def cspdarknet_n(img_dim=3, num_classes=1000) -> CSPDarkNet:
+    return CSPDarkNet(img_dim=img_dim,
+                       width=0.25,
+                       depth=0.34,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def cspdarknet_s(img_dim=3, num_classes=1000) -> CSPDarkNet:
+    return CSPDarkNet(img_dim=img_dim,
+                       width=0.50,
+                       depth=0.34,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def cspdarknet_m(img_dim=3, num_classes=1000) -> CSPDarkNet:
+    return CSPDarkNet(img_dim=img_dim,
+                       width=0.75,
+                       depth=0.67,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def cspdarknet_l(img_dim=3, num_classes=1000) -> CSPDarkNet:
+    return CSPDarkNet(img_dim=img_dim,
+                       width=1.0,
+                       depth=1.0,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def cspdarknet_x(img_dim=3, num_classes=1000) -> CSPDarkNet:
+    return CSPDarkNet(img_dim=img_dim,
+                       width=1.25,
+                       depth=1.34,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+
+if __name__ == '__main__':
+    import torch
+    from thop import profile
+
+    # build model
+    model = cspdarknet_s()
+
+    x = torch.randn(1, 3, 224, 224)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
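+    # thop's profile() reports MACs; multiplying by 2 approximates FLOPs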
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 136 - 0
iclab/models/cspdarknet/modules.py

@@ -0,0 +1,136 @@
+import torch
+import torch.nn as nn
+from   typing import List
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# ---------------------------- Basic Modules ----------------------------
+class YoloBottleneck(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 kernel_size  :List  = [1, 3],
+                 expand_ratio :float = 0.5,
+                 shortcut     :bool  = False,
+                 act_type     :str   = 'silu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False,
+                 ) -> None:
+        super(YoloBottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)
+        # ----------------- Network setting -----------------
+        self.conv_layer1 = BasicConv(in_dim, inter_dim,
+                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.conv_layer2 = BasicConv(inter_dim, out_dim,
+                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.conv_layer2(self.conv_layer1(x))
+
+        return x + h if self.shortcut else h
+
+class CSPBlock(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 num_blocks   :int   = 1,
+                 expand_ratio :float = 0.5,
+                 shortcut     :bool  = False,
+                 act_type     :str   = 'silu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False,
+                 ):
+        super(CSPBlock, self).__init__()
+        # ---------- Basic parameters ----------
+        self.num_blocks = num_blocks
+        self.expand_ratio = expand_ratio
+        self.shortcut = shortcut
+        inter_dim = round(out_dim * expand_ratio)
+        # ---------- Model parameters ----------
+        self.conv_layer_1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.conv_layer_2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.conv_layer_3 = BasicConv(inter_dim * 2, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.module       = nn.Sequential(*[YoloBottleneck(inter_dim,
+                                                           inter_dim,
+                                                           kernel_size  = [1, 3],
+                                                           expand_ratio = 1.0,
+                                                           shortcut     = shortcut,
+                                                           act_type     = act_type,
+                                                           norm_type    = norm_type,
+                                                           depthwise    = depthwise)
+                                                           for _ in range(num_blocks)
+                                                           ])
+
+    def forward(self, x):
+        x1 = self.conv_layer_1(x)
+        x2 = self.module(self.conv_layer_2(x))
+        out = self.conv_layer_3(torch.cat([x1, x2], dim=1))
+
+        return out
+    

+ 18 - 0
iclab/models/darknet/build.py

@@ -0,0 +1,18 @@
+from .darknet import darknet_n, darknet_s, darknet_m, darknet_l, darknet_x
+
+def build_darknet(args):
+    # build a darknet variant
+    if   args.model == 'darknet_n':
+        model = darknet_n(args.img_dim, args.num_classes)
+    elif args.model == 'darknet_s':
+        model = darknet_s(args.img_dim, args.num_classes)
+    elif args.model == 'darknet_m':
+        model = darknet_m(args.img_dim, args.num_classes)
+    elif args.model == 'darknet_l':
+        model = darknet_l(args.img_dim, args.num_classes)
+    elif args.model == 'darknet_x':
+        model = darknet_x(args.img_dim, args.num_classes)
+    else:
+        raise NotImplementedError("Unknown darknet: {}".format(args.model))
+    
+    return model

+ 170 - 0
iclab/models/darknet/darknet.py

@@ -0,0 +1,170 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .modules import BasicConv, ResBlock
+except:
+    from  modules import BasicConv, ResBlock
+
+
+# ---------------------------- DarkNet ----------------------------
+# Modified DarkNet
+class DarkNet(nn.Module):
+    def __init__(self, img_dim=3, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False, num_classes=1000):
+        super(DarkNet, self).__init__()
+        # ---------------- Basic parameters ----------------
+        self.width_factor = width
+        self.depth_factor = depth
+        self.feat_dims = [round(64 * width),
+                          round(128 * width),
+                          round(256 * width),
+                          round(512 * width),
+                          round(1024 * width)
+                          ]
+
+        # ---------------- Model parameters ----------------
+        ## P1/2
+        self.layer_1 = BasicConv(img_dim, self.feat_dims[0],
+                                 kernel_size=6, padding=2, stride=2,
+                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        
+        ## P2/4
+        self.layer_2 = nn.Sequential(
+            BasicConv(self.feat_dims[0], self.feat_dims[1],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ResBlock(self.feat_dims[1],
+                     self.feat_dims[1],
+                     num_blocks   = round(3*depth),
+                     expand_ratio = 0.5,
+                     shortcut     = True,
+                     act_type     = act_type,
+                     norm_type    = norm_type,
+                     depthwise    = depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            BasicConv(self.feat_dims[1], self.feat_dims[2],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ResBlock(self.feat_dims[2],
+                     self.feat_dims[2],
+                     num_blocks   = round(9*depth),
+                     expand_ratio = 0.5,
+                     shortcut     = True,
+                     act_type     = act_type,
+                     norm_type    = norm_type,
+                     depthwise    = depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            BasicConv(self.feat_dims[2], self.feat_dims[3],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ResBlock(self.feat_dims[3],
+                     self.feat_dims[3],
+                     num_blocks   = round(9*depth),
+                     expand_ratio = 0.5,
+                     shortcut     = True,
+                     act_type     = act_type,
+                     norm_type    = norm_type,
+                     depthwise    = depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            BasicConv(self.feat_dims[3], self.feat_dims[4],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ResBlock(self.feat_dims[4],
+                     self.feat_dims[4],
+                     num_blocks   = round(3*depth),
+                     expand_ratio = 0.5,
+                     shortcut     = True,
+                     act_type     = act_type,
+                     norm_type    = norm_type,
+                     depthwise    = depthwise)
+        )
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(self.feat_dims[4], num_classes)
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        c5 = self.avgpool(c5)
+        c5 = torch.flatten(c5, 1)
+        c5 = self.fc(c5)
+
+        return c5
+
+
+# ------------------------ Model Functions ------------------------
+def darknet_n(img_dim=3, num_classes=1000) -> DarkNet:
+    return DarkNet(img_dim=img_dim,
+                       width=0.25,
+                       depth=0.34,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def darknet_s(img_dim=3, num_classes=1000) -> DarkNet:
+    return DarkNet(img_dim=img_dim,
+                       width=0.50,
+                       depth=0.34,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def darknet_m(img_dim=3, num_classes=1000) -> DarkNet:
+    return DarkNet(img_dim=img_dim,
+                       width=0.75,
+                       depth=0.67,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def darknet_l(img_dim=3, num_classes=1000) -> DarkNet:
+    return DarkNet(img_dim=img_dim,
+                       width=1.0,
+                       depth=1.0,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def darknet_x(img_dim=3, num_classes=1000) -> DarkNet:
+    return DarkNet(img_dim=img_dim,
+                       width=1.25,
+                       depth=1.34,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+
+if __name__ == '__main__':
+    import torch
+    from thop import profile
+
+    # build model
+    model = darknet_s()
+
+    x = torch.randn(1, 3, 224, 224)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
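+    # thop's profile() reports MACs; multiplying by 2 approximates FLOPs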
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 130 - 0
iclab/models/darknet/modules.py

@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+from   typing import List
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# ---------------------------- Basic Modules ----------------------------
+class YoloBottleneck(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 kernel_size  :List  = [1, 3],
+                 expand_ratio :float = 0.5,
+                 shortcut     :bool  = False,
+                 act_type     :str   = 'silu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False,
+                 ) -> None:
+        super(YoloBottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)
+        # ----------------- Network setting -----------------
+        self.conv_layer1 = BasicConv(in_dim, inter_dim,
+                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.conv_layer2 = BasicConv(inter_dim, out_dim,
+                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.conv_layer2(self.conv_layer1(x))
+
+        return x + h if self.shortcut else h
+
+class ResBlock(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 num_blocks   :int   = 1,
+                 expand_ratio :float = 0.5,
+                 shortcut     :bool  = False,
+                 act_type     :str   = 'silu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False,
+                 ):
+        super(ResBlock, self).__init__()
+        # ---------- Basic parameters ----------
+        self.num_blocks = num_blocks
+        self.expand_ratio = expand_ratio
+        self.shortcut = shortcut
+        # ---------- Model parameters ----------
+        self.module = nn.Sequential(*[YoloBottleneck(in_dim,
+                                                     out_dim,
+                                                     kernel_size  = [1, 3],
+                                                     expand_ratio = expand_ratio,
+                                                     shortcut     = shortcut,
+                                                     act_type     = act_type,
+                                                     norm_type    = norm_type,
+                                                     depthwise    = depthwise)
+                                                     for _ in range(num_blocks)
+                                                     ])
+
+    def forward(self, x):
+        out = self.module(x)
+
+        return out
+    

+ 16 - 0
iclab/models/elandarknet/build.py

@@ -0,0 +1,16 @@
+from .elandarknet import elandarknet_n, elandarknet_s, elandarknet_m, elandarknet_l, elandarknet_x
+
+def build_elandarknet(args):
+    # build an elandarknet variant
+    if   args.model == 'elandarknet_n':
+        model = elandarknet_n(args.img_dim, args.num_classes)
+    elif args.model == 'elandarknet_s':
+        model = elandarknet_s(args.img_dim, args.num_classes)
+    elif args.model == 'elandarknet_m':
+        model = elandarknet_m(args.img_dim, args.num_classes)
+    elif args.model == 'elandarknet_l':
+        model = elandarknet_l(args.img_dim, args.num_classes)
+    elif args.model == 'elandarknet_x':
+        model = elandarknet_x(args.img_dim, args.num_classes)
+    else:
+        raise NotImplementedError("Unknown elandarknet: {}".format(args.model))
+
+    return model

+ 171 - 0
iclab/models/elandarknet/elandarknet.py

@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .modules import BasicConv, ELANLayer
+except:
+    from  modules import BasicConv, ELANLayer
+   
+
+## ELAN-based DarkNet
+class ELANDarkNet(nn.Module):
+    def __init__(self, img_dim=3, width=1.0, depth=1.0, ratio=1.0, num_classes=1000, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANDarkNet, self).__init__()
+        # ---------------- Basic parameters ----------------
+        self.width_factor = width
+        self.depth_factor = depth
+        self.last_stage_factor = ratio
+        self.feat_dims = [round(64 * width),
+                          round(128 * width),
+                          round(256 * width),
+                          round(512 * width),
+                          round(512 * width * ratio)
+                          ]
+        # ---------------- Network parameters ----------------
+        ## P1/2
+        self.layer_1 = BasicConv(img_dim, self.feat_dims[0],
+                                 kernel_size=3, padding=1, stride=2,
+                                 act_type=act_type, norm_type=norm_type)
+        ## P2/4
+        self.layer_2 = nn.Sequential(
+            BasicConv(self.feat_dims[0], self.feat_dims[1],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ELANLayer(in_dim     = self.feat_dims[1],
+                      out_dim    = self.feat_dims[1],
+                      num_blocks = round(3*depth),
+                      shortcut   = True,
+                      act_type   = act_type,
+                      norm_type  = norm_type,
+                      depthwise  = depthwise,
+                      )
+        )
+        ## P3/8
+        self.layer_3 = nn.Sequential(
+            BasicConv(self.feat_dims[1], self.feat_dims[2],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ELANLayer(in_dim     = self.feat_dims[2],
+                      out_dim    = self.feat_dims[2],
+                      num_blocks = round(6*depth),
+                      shortcut   = True,
+                      act_type   = act_type,
+                      norm_type  = norm_type,
+                      depthwise  = depthwise,
+                      )
+        )
+        ## P4/16
+        self.layer_4 = nn.Sequential(
+            BasicConv(self.feat_dims[2], self.feat_dims[3],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ELANLayer(in_dim     = self.feat_dims[3],
+                      out_dim    = self.feat_dims[3],
+                      num_blocks = round(6*depth),
+                      shortcut   = True,
+                      act_type   = act_type,
+                      norm_type  = norm_type,
+                      depthwise  = depthwise,
+                      )
+        )
+        ## P5/32
+        self.layer_5 = nn.Sequential(
+            BasicConv(self.feat_dims[3], self.feat_dims[4],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            ELANLayer(in_dim     = self.feat_dims[4],
+                      out_dim    = self.feat_dims[4],
+                      num_blocks = round(3*depth),
+                      shortcut   = True,
+                      act_type   = act_type,
+                      norm_type  = norm_type,
+                      depthwise  = depthwise,
+                      )
+        )
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(self.feat_dims[4], num_classes)
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        c5 = self.avgpool(c5)
+        c5 = torch.flatten(c5, 1)
+        c5 = self.fc(c5)
+
+        return c5
+
+
+# ------------------------ Model Functions ------------------------
+def elandarknet_n(img_dim=3, num_classes=1000) -> ELANDarkNet:
+    return ELANDarkNet(img_dim=img_dim,
+                       width=0.25,
+                       depth=0.34,
+                       ratio=2.0,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def elandarknet_s(img_dim=3, num_classes=1000) -> ELANDarkNet:
+    return ELANDarkNet(img_dim=img_dim,
+                       width=0.50,
+                       depth=0.34,
+                       ratio=2.0,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def elandarknet_m(img_dim=3, num_classes=1000) -> ELANDarkNet:
+    return ELANDarkNet(img_dim=img_dim,
+                       width=0.75,
+                       depth=0.67,
+                       ratio=1.5,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def elandarknet_l(img_dim=3, num_classes=1000) -> ELANDarkNet:
+    return ELANDarkNet(img_dim=img_dim,
+                       width=1.0,
+                       depth=1.0,
+                       ratio=1.0,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
+
+def elandarknet_x(img_dim=3, num_classes=1000) -> ELANDarkNet:
+    return ELANDarkNet(img_dim=img_dim,
+                       width=1.25,
+                       depth=1.34,
+                       ratio=1.0,
+                       act_type='silu',
+                       norm_type='BN',
+                       depthwise=False,
+                       num_classes=num_classes
+                       )
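+
+# NOTE: the (width, depth, ratio) triplets above follow YOLOv8's N/S/M/L/X
+# scaling convention: `width` scales channel widths, `depth` scales bottleneck
+# counts (via round(n * depth)), and `ratio` adjusts the last stage's width.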
+
+
+if __name__ == '__main__':
+    import torch
+    from thop import profile
+
+    # build model
+    model = elandarknet_s()
+
+    x = torch.randn(1, 3, 224, 224)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
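+    # NOTE: thop counts multiply-accumulates (MACs); the x2 below converts MACs to FLOPs.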
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 135 - 0
iclab/models/elandarknet/modules.py

@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# --------------------- Yolov8 modules ---------------------
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 kernel_size  :List  = [1, 3],
+                 expand_ratio :float = 0.5,
+                 shortcut     :bool  = False,
+                 act_type     :str   = 'silu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False,
+                 ) -> None:
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)
+        # ----------------- Network setting -----------------
+        self.conv_layer1 = BasicConv(in_dim, inter_dim,
+                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.conv_layer2 = BasicConv(inter_dim, out_dim,
+                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.conv_layer2(self.conv_layer1(x))
+
+        return x + h if self.shortcut else h
+
+class ELANLayer(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio :float = 0.5,
+                 num_blocks   :int   = 1,
+                 shortcut     :bool  = False,
+                 act_type     :str   = 'silu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False,
+                 ) -> None:
+        super(ELANLayer, self).__init__()
+        self.inter_dim = round(out_dim * expand_ratio)
+        self.input_proj  = BasicConv(in_dim, self.inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
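+        # Channel accounting: input_proj yields 2 * inter_dim (chunked into x1, x2),
+        # and each of the num_blocks bottlenecks appends inter_dim more, hence the
+        # (2 + num_blocks) * inter_dim input to output_proj below.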
+        self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.module = nn.ModuleList([Bottleneck(self.inter_dim,
+                                                    self.inter_dim,
+                                                    kernel_size  = [3, 3],
+                                                    expand_ratio = 1.0,
+                                                    shortcut     = shortcut,
+                                                    act_type     = act_type,
+                                                    norm_type    = norm_type,
+                                                    depthwise    = depthwise)
+                                                    for _ in range(num_blocks)])
+
+    def forward(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
+        out = [x1, x2]
+
+        # Bottleneck blocks
+        out.extend(m(out[-1]) for m in self.module)
+
+        # Output proj
+        out = self.output_proj(torch.cat(out, dim=1))
+
+        return out
+   

+ 12 - 0
iclab/models/gelan/build.py

@@ -0,0 +1,12 @@
+from .gelan import gelan_s, gelan_c
+
+def build_gelan(args):
+    # build gelan model
+    if   args.model == 'gelan_s':
+        model = gelan_s(args.img_dim, args.num_classes)
+    elif args.model == 'gelan_c':
+        model = gelan_c(args.img_dim, args.num_classes)
+    else:
+        raise NotImplementedError("Unknown gelan model: {}".format(args.model))
+    
+    return model
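+
+# Usage sketch (hypothetical `args` namespace, mirroring train.py's parser):
+#   from types import SimpleNamespace
+#   args = SimpleNamespace(model='gelan_c', img_dim=3, num_classes=1000)
+#   model = build_gelan(args)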

+ 233 - 0
iclab/models/gelan/gelan.py

@@ -0,0 +1,233 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .modules import BasicConv, RepGElanLayer, ADown
+except ImportError:
+    from modules import BasicConv, RepGElanLayer, ADown
+
+
+# ---------------------------- GELAN Backbone ----------------------------
+class GElanCBackbone(nn.Module):
+    def __init__(self, img_dim=3, num_classes=1000, act_type='silu', norm_type='BN', depthwise=False):
+        super(GElanCBackbone, self).__init__()
+        # ------------------ Basic setting ------------------
+        self.feat_dims = {
+            "c1": [64],
+            "c2": [128, [128, 64],  256],
+            "c3": [256, [256, 128], 512],
+            "c4": [512, [512, 256], 512],
+            "c5": [512, [512, 256], 512],
+        }
+        
+        # ------------------ Network setting ------------------
+        ## P1/2
+        self.layer_1 = BasicConv(img_dim, self.feat_dims["c1"][0],
+                                 kernel_size=3, padding=1, stride=2,
+                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        # P2/4
+        self.layer_2 = nn.Sequential(
+            BasicConv(self.feat_dims["c1"][0], self.feat_dims["c2"][0],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            RepGElanLayer(in_dim     = self.feat_dims["c2"][0],
+                          inter_dims = self.feat_dims["c2"][1],
+                          out_dim    = self.feat_dims["c2"][2],
+                          num_blocks = 1,
+                          shortcut   = True,
+                          act_type   = act_type,
+                          norm_type  = norm_type,
+                          depthwise  = depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            ADown(self.feat_dims["c2"][2], self.feat_dims["c3"][0],
+                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            RepGElanLayer(in_dim     = self.feat_dims["c3"][0],
+                          inter_dims = self.feat_dims["c3"][1],
+                          out_dim    = self.feat_dims["c3"][2],
+                          num_blocks = 1,
+                          shortcut   = True,
+                          act_type   = act_type,
+                          norm_type  = norm_type,
+                          depthwise  = depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            ADown(self.feat_dims["c3"][2], self.feat_dims["c4"][0],
+                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            RepGElanLayer(in_dim     = self.feat_dims["c4"][0],
+                          inter_dims = self.feat_dims["c4"][1],
+                          out_dim    = self.feat_dims["c4"][2],
+                          num_blocks = 1,
+                          shortcut   = True,
+                          act_type   = act_type,
+                          norm_type  = norm_type,
+                          depthwise  = depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            ADown(self.feat_dims["c4"][2], self.feat_dims["c5"][0],
+                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            RepGElanLayer(in_dim     = self.feat_dims["c5"][0],
+                          inter_dims = self.feat_dims["c5"][1],
+                          out_dim    = self.feat_dims["c5"][2],
+                          num_blocks = 1,
+                          shortcut   = True,
+                          act_type   = act_type,
+                          norm_type  = norm_type,
+                          depthwise  = depthwise)
+        )
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(self.feat_dims["c5"][2], num_classes)
+
+        # Initialize all layers
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # Reset Conv2d parameters to stay consistent with the
+                # official YOLOv9 initialization.
+                m.reset_parameters()
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        c5 = self.avgpool(c5)
+        c5 = torch.flatten(c5, 1)
+        c5 = self.fc(c5)
+
+        return c5
+
+class GElanSBackbone(nn.Module):
+    def __init__(self, img_dim=3, num_classes=1000, act_type='silu', norm_type='BN', depthwise=False):
+        super(GElanSBackbone, self).__init__()
+        # ------------------ Basic setting ------------------
+        self.feat_dims = {
+            "c1": [32],
+            "c2": [64,  [64, 32],   64],
+            "c3": [64,  [64, 32],   128],
+            "c4": [128, [128, 64],  256],
+            "c5": [256, [256, 128], 256],
+        }
+        
+        # ------------------ Network setting ------------------
+        ## P1/2
+        self.layer_1 = BasicConv(img_dim, self.feat_dims["c1"][0],
+                                 kernel_size=3, padding=1, stride=2,
+                                 act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        # P2/4
+        self.layer_2 = nn.Sequential(
+            BasicConv(self.feat_dims["c1"][0], self.feat_dims["c2"][0],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            RepGElanLayer(in_dim     = self.feat_dims["c2"][0],
+                          inter_dims = self.feat_dims["c2"][1],
+                          out_dim    = self.feat_dims["c2"][2],
+                          num_blocks = 3,
+                          shortcut   = True,
+                          act_type   = act_type,
+                          norm_type  = norm_type,
+                          depthwise  = depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            ADown(self.feat_dims["c2"][2], self.feat_dims["c3"][0],
+                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            RepGElanLayer(in_dim     = self.feat_dims["c3"][0],
+                          inter_dims = self.feat_dims["c3"][1],
+                          out_dim    = self.feat_dims["c3"][2],
+                          num_blocks = 3,
+                          shortcut   = True,
+                          act_type   = act_type,
+                          norm_type  = norm_type,
+                          depthwise  = depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            ADown(self.feat_dims["c3"][2], self.feat_dims["c4"][0],
+                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            RepGElanLayer(in_dim     = self.feat_dims["c4"][0],
+                          inter_dims = self.feat_dims["c4"][1],
+                          out_dim    = self.feat_dims["c4"][2],
+                          num_blocks = 3,
+                          shortcut   = True,
+                          act_type   = act_type,
+                          norm_type  = norm_type,
+                          depthwise  = depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            ADown(self.feat_dims["c4"][2], self.feat_dims["c5"][0],
+                  act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            RepGElanLayer(in_dim     = self.feat_dims["c5"][0],
+                          inter_dims = self.feat_dims["c5"][1],
+                          out_dim    = self.feat_dims["c5"][2],
+                          num_blocks = 3,
+                          shortcut   = True,
+                          act_type   = act_type,
+                          norm_type  = norm_type,
+                          depthwise  = depthwise)
+        )
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(self.feat_dims["c5"][2], num_classes)
+
+        # Initialize all layers
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # Reset Conv2d parameters to stay consistent with the
+                # official YOLOv9 initialization.
+                m.reset_parameters()
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        c5 = self.avgpool(c5)
+        c5 = torch.flatten(c5, 1)
+        c5 = self.fc(c5)
+
+        return c5
+
+
+# ---------------------------- Functions ----------------------------
+def gelan_c(img_dim=3, num_classes=1000):
+    return GElanCBackbone(img_dim,
+                          num_classes=num_classes,
+                          act_type='silu',
+                          norm_type='BN',
+                          depthwise=False)
+
+def gelan_s(img_dim=3, num_classes=1000):
+    return GElanSBackbone(img_dim,
+                          num_classes=num_classes,
+                          act_type='silu',
+                          norm_type='BN',
+                          depthwise=False)
+
+
+if __name__ == '__main__':
+    import torch
+    from thop import profile
+
+    # build model
+    model = gelan_c()
+
+    x = torch.randn(1, 3, 224, 224)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 312 - 0
iclab/models/gelan/modules.py

@@ -0,0 +1,312 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 group=1,                  # group
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=group)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# --------------------- GELAN modules (from yolov9) ---------------------
+class ADown(nn.Module):
+    def __init__(self, in_dim, out_dim, act_type="silu", norm_type="BN", depthwise=False):
+        super().__init__()
+        inter_dim = out_dim // 2
+        self.conv_layer_1 = BasicConv(in_dim // 2, inter_dim,
+                                    kernel_size=3, padding=1, stride=2,
+                                    act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.conv_layer_2 = BasicConv(in_dim // 2, inter_dim, kernel_size=1,
+                                    act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+    def forward(self, x):
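+        # avg_pool2d positional args: kernel_size=2, stride=1, padding=0,
+        # ceil_mode=False, count_include_pad=True (a light smoothing pass
+        # before the two stride-2 branches below).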
+        x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
+        x1, x2 = x.chunk(2, 1)
+        x1 = self.conv_layer_1(x1)
+        x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
+        x2 = self.conv_layer_2(x2)
+
+        return torch.cat((x1, x2), 1)
+
+class RepConvN(nn.Module):
+    """RepConv is a basic rep-style block, including training and deploy status
+    This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
+    """
+    def __init__(self, in_dim, out_dim, k=3, s=1, p=1, g=1, act_type='silu', norm_type='BN', depthwise=False):
+        super().__init__()
+        assert k == 3 and p == 1
+        self.g = g
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.act = get_activation(act_type)
+
+        self.bn = None
+        self.conv1 = BasicConv(in_dim, out_dim,
+                               kernel_size=k, padding=p, stride=s, group=g,
+                               act_type=None, norm_type=norm_type, depthwise=depthwise)
+        self.conv2 = BasicConv(in_dim, out_dim,
+                               kernel_size=1, padding=(p - k // 2), stride=s, group=g,
+                               act_type=None, norm_type=norm_type, depthwise=depthwise)
+
+    def forward(self, x):
+        """Forward process"""
+        if hasattr(self, 'conv'):
+            return self.forward_fuse(x)
+        else:
+            id_out = 0 if self.bn is None else self.bn(x)
+            return self.act(self.conv1(x) + self.conv2(x) + id_out)
+
+    def forward_fuse(self, x):
+        """Forward process"""
+        return self.act(self.conv(x))
+
+    def get_equivalent_kernel_bias(self):
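+        # Deploy-time re-parameterization: fold BN into each branch, pad the
+        # 1x1 kernel to 3x3, and sum all branches (plus the optional BN
+        # identity branch) into one equivalent 3x3 conv.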
+        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
+        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
+        kernelid, biasid = self._fuse_bn_tensor(self.bn)
+        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
+
+    def _avg_to_3x3_tensor(self, avgp):
+        channels = self.in_dim
+        groups = self.g
+        kernel_size = avgp.kernel_size
+        input_dim = channels // groups
+        k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
+        k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
+        return k
+
+    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
+        if kernel1x1 is None:
+            return 0
+        else:
+            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
+
+    def _fuse_bn_tensor(self, branch):
+        if branch is None:
+            return 0, 0
+        if isinstance(branch, BasicConv):
+            kernel       = branch.conv.weight
+            running_mean = branch.norm.running_mean
+            running_var  = branch.norm.running_var
+            gamma        = branch.norm.weight
+            beta         = branch.norm.bias
+            eps          = branch.norm.eps
+        elif isinstance(branch, nn.BatchNorm2d):
+            if not hasattr(self, 'id_tensor'):
+                input_dim = self.in_dim // self.g
+                kernel_value = np.zeros((self.in_dim, input_dim, 3, 3), dtype=np.float32)
+                for i in range(self.in_dim):
+                    kernel_value[i, i % input_dim, 1, 1] = 1
+                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
+            kernel       = self.id_tensor
+            running_mean = branch.running_mean
+            running_var  = branch.running_var
+            gamma        = branch.weight
+            beta         = branch.bias
+            eps          = branch.eps
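+        # BN folding: W' = W * (gamma / std), b' = beta - running_mean * (gamma / std)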
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std
+
+    def fuse_convs(self):
+        if hasattr(self, 'conv'):
+            return
+        kernel, bias = self.get_equivalent_kernel_bias()
+        self.conv = nn.Conv2d(in_channels  = self.conv1.conv.in_channels,
+                              out_channels = self.conv1.conv.out_channels,
+                              kernel_size  = self.conv1.conv.kernel_size,
+                              stride       = self.conv1.conv.stride,
+                              padding      = self.conv1.conv.padding,
+                              dilation     = self.conv1.conv.dilation,
+                              groups       = self.conv1.conv.groups,
+                              bias         = True).requires_grad_(False)
+        self.conv.weight.data = kernel
+        self.conv.bias.data = bias
+        for para in self.parameters():
+            para.detach_()
+        self.__delattr__('conv1')
+        self.__delattr__('conv2')
+        if hasattr(self, 'nm'):
+            self.__delattr__('nm')
+        if hasattr(self, 'bn'):
+            self.__delattr__('bn')
+        if hasattr(self, 'id_tensor'):
+            self.__delattr__('id_tensor')
+
+class RepNBottleneck(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 shortcut=True,
+                 kernel_size=(3, 3),
+                 expansion=0.5,
+                 act_type='silu',
+                 norm_type='BN',
+                 depthwise=False
+                 ):
+        super().__init__()
+        inter_dim = round(out_dim * expansion)
+        self.conv_layer_1 = RepConvN(in_dim, inter_dim, kernel_size[0], p=kernel_size[0]//2, s=1, act_type=act_type, norm_type=norm_type)
+        self.conv_layer_2 = BasicConv(inter_dim, out_dim, kernel_size[1], padding=kernel_size[1]//2, stride=1, act_type=act_type, norm_type=norm_type)
+        self.add = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.conv_layer_2(self.conv_layer_1(x))
+        return x + h if self.add else h
+
+class RepNCSP(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 num_blocks=1,
+                 shortcut=True,
+                 expansion=0.5,
+                 act_type='silu',
+                 norm_type='BN',
+                 depthwise=False
+                 ):
+        super().__init__()
+        inter_dim = int(out_dim * expansion)
+        self.conv_layer_1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.conv_layer_2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.conv_layer_3 = BasicConv(2 * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.module       = nn.Sequential(*(RepNBottleneck(inter_dim,
+                                                           inter_dim,
+                                                           kernel_size = [3, 3],
+                                                           shortcut    = shortcut,
+                                                           expansion   = 1.0,
+                                                           act_type    = act_type,
+                                                           norm_type   = norm_type,
+                                                           depthwise   = depthwise)
+                                                           for _ in range(num_blocks)))
+
+    def forward(self, x):
+        x1 = self.conv_layer_1(x)
+        x2 = self.module(self.conv_layer_2(x))
+
+        return self.conv_layer_3(torch.cat([x1, x2], dim=1))
+
+class RepGElanLayer(nn.Module):
+    """YOLOv9's GELAN module"""
+    def __init__(self,
+                 in_dim     :int,
+                 inter_dims :List,
+                 out_dim    :int,
+                 num_blocks :int   = 1,
+                 shortcut   :bool  = False,
+                 act_type   :str   = 'silu',
+                 norm_type  :str   = 'BN',
+                 depthwise  :bool  = False,
+                 ) -> None:
+        super(RepGElanLayer, self).__init__()
+        # ----------- Basic parameters -----------
+        self.in_dim = in_dim
+        self.inter_dims = inter_dims
+        self.out_dim = out_dim
+
+        # ----------- Network parameters -----------
+        self.conv_layer_1  = BasicConv(in_dim, inter_dims[0], kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.elan_module_1 = nn.Sequential(
+             RepNCSP(inter_dims[0]//2,
+                     inter_dims[1],
+                     num_blocks  = num_blocks,
+                     shortcut    = shortcut,
+                     expansion   = 0.5,
+                     act_type    = act_type,
+                     norm_type   = norm_type,
+                     depthwise   = depthwise),
+            BasicConv(inter_dims[1], inter_dims[1],
+                      kernel_size=3, padding=1,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        self.elan_module_2 = nn.Sequential(
+             RepNCSP(inter_dims[1],
+                     inter_dims[1],
+                     num_blocks  = num_blocks,
+                     shortcut    = shortcut,
+                     expansion   = 0.5,
+                     act_type    = act_type,
+                     norm_type   = norm_type,
+                     depthwise   = depthwise),
+            BasicConv(inter_dims[1], inter_dims[1],
+                      kernel_size=3, padding=1,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
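+        # Concat width: inter_dims[0] (the two input chunks) plus 2 * inter_dims[1]
+        # (one per ELAN stage) feeds the final 1x1 projection.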
+        self.conv_layer_2 = BasicConv(inter_dims[0] + 2*self.inter_dims[1], out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+
+
+    def forward(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.conv_layer_1(x), 2, dim=1)
+        out = [x1, x2]
+
+        # ELAN module
+        out.append(self.elan_module_1(out[-1]))
+        out.append(self.elan_module_2(out[-1]))
+
+        # Output proj
+        out = self.conv_layer_2(torch.cat(out, dim=1))
+
+        return out
+    

+ 5 - 0
iclab/requirements.txt

@@ -0,0 +1,5 @@
+torch
+torchvision
+opencv-python
+thop
+timm

+ 332 - 0
iclab/train.py

@@ -0,0 +1,332 @@
+from copy import deepcopy
+import os
+import time
+import math
+import argparse
+import datetime
+
+# ---------------- Timm components ----------------
+from timm.data.mixup import Mixup
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+
+# ---------------- Torch components ----------------
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+# ---------------- Dataset components ----------------
+from data import build_dataset, build_dataloader
+
+# ---------------- Model components ----------------
+from models import build_model
+
+# ---------------- Utils components ----------------
+from utils import distributed_utils
+from utils.ema import ModelEMA
+from utils.misc import setup_seed, print_rank_0, load_model, save_model
+from utils.misc import NativeScalerWithGradNormCount as NativeScaler
+from utils.optimzer import build_optimizer
+from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
+from utils.com_flops_params import FLOPs_and_Params
+
+# ---------------- Training engine ----------------
+from engine import train_one_epoch, evaluate
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    # Input
+    parser.add_argument('--img_size', type=int, default=224,
+                        help='input image size.')    
+    parser.add_argument('--img_dim', type=int, default=3,
+                        help='3 for RGB; 1 for Gray.')    
+    parser.add_argument('--num_classes', type=int, default=1000,
+                        help='Number of the classes.')    
+    # Basic
+    parser.add_argument('--seed', type=int, default=42,
+                        help='random seed.')
+    parser.add_argument('--cuda', action='store_true', default=False,
+                        help='use cuda')
+    parser.add_argument('--batch_size', type=int, default=256,
+                        help='batch size on all GPUs')
+    parser.add_argument('--num_workers', type=int, default=4,
+                        help='number of workers')
+    parser.add_argument('--path_to_save', type=str, default='weights/',
+                        help='path to save trained model.')
+    parser.add_argument('--tfboard', action='store_true', default=False,
+                        help='use tensorboard')
+    parser.add_argument('--eval', action='store_true', default=False,
+                        help='evaluate model.')
+    # Epoch
+    parser.add_argument('--wp_epoch', type=int, default=20, 
+                        help='number of warmup epochs')
+    parser.add_argument('--start_epoch', type=int, default=0, 
+                        help='start epoch (used when resuming training)')
+    parser.add_argument('--max_epoch', type=int, default=300, 
+                        help='max epoch')
+    parser.add_argument('--eval_epoch', type=int, default=10, 
+                        help='evaluation interval in epochs')
+    # Dataset
+    parser.add_argument('--dataset', type=str, default='cifar10',
+                        help='dataset name')
+    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
+                        help='path to dataset folder')
+    # Model
+    parser.add_argument('-m', '--model', type=str, default='rtcnet_n',
+                        help='model name')
+    parser.add_argument('--resume', default=None, type=str,
+                        help='keep training')
+    parser.add_argument('--ema', action='store_true', default=False,
+                        help='use ema.')
+    parser.add_argument('--drop_path', type=float, default=0.1,
+                        help='drop_path')
+    # Optimizer
+    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
+                        help='sgd, adam, adamw')
+    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='step',
+                        help='cosine, step')
+    parser.add_argument('-mt', '--momentum', type=float, default=0.9,
+                        help='momentum for the SGD optimizer')
+    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
+                        help='weight decay')
+    parser.add_argument('--batch_base', type=int, default=256,
+                        help='base batch size used to scale the learning rate')
+    parser.add_argument('--base_lr', type=float, default=1e-3,
+                        help='learning rate for training model')
+    parser.add_argument('--min_lr', type=float, default=1e-6,
+                        help='the final lr')
+    parser.add_argument('--grad_accumulate', type=int, default=1,
+                        help='gradient accumulation')
+    parser.add_argument('--max_grad_norm', type=float, default=None,
+                        help='Clip gradient norm (default: None, no clipping)')
+    # Augmentation parameters
+    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
+                        help='Color jitter factor (enabled only when not using Auto/RandAug)')
+    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
+                        help='Use AutoAugment policy, e.g. "v0", "original", or "rand-m9-mstd0.5-inc1"')
+    parser.add_argument('--smoothing', type=float, default=0.1,
+                        help='Label smoothing (default: 0.1)')
+    # Random Erase params
+    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+                        help='Random erase prob (default: 0.25)')
+    parser.add_argument('--remode', type=str, default='pixel',
+                        help='Random erase mode (default: "pixel")')
+    parser.add_argument('--recount', type=int, default=1,
+                        help='Random erase count (default: 1)')
+    parser.add_argument('--resplit', action='store_true', default=False,
+                        help='Do not random erase first (clean) augmentation split')
+    # Mixup params
+    parser.add_argument('--mixup', type=float, default=0,
+                        help='mixup alpha, mixup enabled if > 0.')
+    parser.add_argument('--cutmix', type=float, default=0,
+                        help='cutmix alpha, cutmix enabled if > 0.')
+    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
+                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+    parser.add_argument('--mixup_prob', type=float, default=1.0,
+                        help='Probability of performing mixup or cutmix when either/both is enabled')
+    parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
+                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
+    parser.add_argument('--mixup_mode', type=str, default='batch',
+                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+    # DDP
+    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
+                        help='distributed training')
+    parser.add_argument('--dist_url', default='env://', 
+                        help='url used to set up distributed training')
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='number of distributed processes')
+    parser.add_argument('--sybn', action='store_true', default=False, 
+                        help='use sybn.')
+    parser.add_argument('--local_rank', default=-1, type=int,
+                        help='the number of local rank.')
+
+    return parser.parse_args()
+
+    
+def main():
+    args = parse_args()
+    # set random seed
+    setup_seed(args.seed)
+
+    # Path to save model
+    path_to_save = os.path.join(args.path_to_save, args.dataset, args.model)
+    os.makedirs(path_to_save, exist_ok=True)
+    args.output_dir = path_to_save
+    
+    # ------------------------- Build DDP environment -------------------------
+    ## local_rank is the global rank of the process; its value range is [0, world_size - 1].
+    ## local_process_rank is the per-machine GPU index, not the global rank.
+    local_rank = local_process_rank = -1
+    if args.distributed:
+        distributed_utils.init_distributed_mode(args)
+        print("git:\n  {}\n".format(distributed_utils.get_sha()))
+        try:
+            # Multiple machines & multiple GPUs (world size > 8)
+            local_rank = torch.distributed.get_rank()
+            local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
+        except Exception:
+            # Single machine & multiple GPUs (world size <= 8)
+            local_rank = local_process_rank = torch.distributed.get_rank()
+    print_rank_0(args)
+    args.world_size = distributed_utils.get_world_size()
+    print('World size: {}'.format(distributed_utils.get_world_size()))
+    print("LOCAL RANK: ", local_rank)
+    print("LOCAL_PROCESS_RANL: ", local_process_rank)
+
+    # ------------------------- Build CUDA -------------------------
+    if args.cuda:
+        if torch.cuda.is_available():
+            cudnn.benchmark = True
+            device = torch.device("cuda")
+        else:
+            print('There is no available GPU.')
+            args.cuda = False
+            device = torch.device("cpu")
+    else:
+        device = torch.device("cpu")
+
+    # ------------------------- Build Tensorboard -------------------------
+    tblogger = None
+    if local_rank <= 0 and args.tfboard:
+        print('use tensorboard')
+        from torch.utils.tensorboard import SummaryWriter
+        time_stamp = time.strftime('%Y-%m-%d_%H:%M:%S',time.localtime(time.time()))
+        log_path = os.path.join('log/', args.dataset, time_stamp)
+        os.makedirs(log_path, exist_ok=True)
+        tblogger = SummaryWriter(log_path)
+
+    # ------------------------- Build Dataset -------------------------
+    train_dataset = build_dataset(args, is_train=True)
+    val_dataset   = build_dataset(args, is_train=False)
+
+    # ------------------------- Build Dataloader -------------------------
+    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
+    val_dataloader   = build_dataloader(args, val_dataset,   is_train=False)
+
+    print('=================== Dataset Information ===================')
+    print("Dataset: ", args.dataset)
+    print('- train dataset size : ', len(train_dataset))
+    print('- val dataset size   : ', len(val_dataset))
+
+    # ------------------------- Build Model -------------------------
+    model = build_model(args)
+    model.train().to(device)
+    print(model)
+    if local_rank <= 0:
+        model_copy = deepcopy(model)
+        model_copy.eval()
+        FLOPs_and_Params(model_copy, args.img_size, args.img_dim, device)
+        model_copy.train()
+        del model_copy
+    if args.distributed:
+        # wait for all processes to synchronize
+        dist.barrier()
+
+    # ------------------------- Build DDP Model -------------------------
+    model_without_ddp = model
+    if args.distributed:
+        if args.sybn:
+            print('use SyncBatchNorm ...')
+            # SyncBatchNorm conversion must happen before wrapping with DDP.
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        model = DDP(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    # ------------------------- Mixup augmentation config -------------------------
+    mixup_fn = None
+    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+    if mixup_active:
+        print_rank_0("Mixup is activated!", local_rank)
+        mixup_fn = Mixup(mixup_alpha     = args.mixup,
+                         cutmix_alpha    = args.cutmix,
+                         cutmix_minmax   = args.cutmix_minmax,
+                         prob            = args.mixup_prob,
+                         switch_prob     = args.mixup_switch_prob,
+                         mode            = args.mixup_mode,
+                         label_smoothing = args.smoothing,
+                         num_classes     = args.num_classes)
+
+
+    # ------------------------- Build Optimizer -------------------------
+    optimizer = build_optimizer(args, model_without_ddp)
+    loss_scaler = NativeScaler()
+
+    # ------------------------- Build Lr Scheduler -------------------------
+    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
+    lr_scheduler = build_lr_scheduler(args, optimizer)
+
+    # ------------------------- Build Criterion -------------------------
+    if mixup_fn is not None:
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif args.smoothing > 0.:
+        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+    else:
+        criterion = torch.nn.CrossEntropyLoss()
+    load_model(args=args, model_without_ddp=model_without_ddp,
+               optimizer=optimizer, lr_scheduler=lr_scheduler, loss_scaler=loss_scaler)
+
+    # ------------------------- Build Model-EMA -------------------------
+    if args.ema:
+        print("Build model ema for {}".format(args.model))
+        updates = args.start_epoch * len(train_dataloader) // args.grad_accumulate
+        print("Initialial updates of ModelEMA: {}".format(updates))
+        model_ema = ModelEMA(model_without_ddp, ema_decay=0.999, ema_tau=2000., updates=updates)
+    else:
+        model_ema = None
+
+    # ------------------------- Eval before Train Pipeline -------------------------
+    if args.eval:
+        print('evaluating ...')
+        test_stats = evaluate(val_dataloader, model_without_ddp, device, local_rank)
+        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5 : %.2f]' %
+                (test_stats['loss'], test_stats['acc1'], test_stats['acc5']), flush=True)
+        return
+
+    # ------------------------- Training Pipeline -------------------------
+    start_time = time.time()
+    max_accuracy = -1.0
+    print_rank_0("=============== Start training for {} epochs ===============".format(args.max_epoch), local_rank)
+    for epoch in range(args.start_epoch, args.max_epoch):
+        if args.distributed:
+            train_dataloader.batch_sampler.sampler.set_epoch(epoch)
+
+        # train one epoch
+        train_one_epoch(args, device, model, model_ema, train_dataloader, optimizer, epoch,
+                        lr_scheduler_warmup, loss_scaler, criterion, local_rank, tblogger, mixup_fn)
+
+        # LR scheduler
+        if (epoch + 1) > args.wp_epoch:
+            lr_scheduler.step()
+
+        # Evaluate
+        if local_rank <= 0:
+            model_eval = model_ema.ema if model_ema is not None else model_without_ddp
+            if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
+                print_rank_0("Evaluating ...")
+                test_stats = evaluate(val_dataloader, model_eval, device, local_rank)
+                print_rank_0(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%", local_rank)
+                max_accuracy = max(max_accuracy, test_stats["acc1"])
+                print_rank_0(f'Max accuracy: {max_accuracy:.2f}%', local_rank)
+
+                # Log eval metrics here: test_stats only exists on evaluation epochs,
+                # so logging outside this branch would raise a NameError.
+                if tblogger is not None:
+                    tblogger.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
+                    tblogger.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
+                    tblogger.add_scalar('perf/test_loss', test_stats['loss'], epoch)
+
+                # Save model
+                print('- saving the model after {} epochs ...'.format(epoch))
+                save_model(args=args, model=model_eval, model_without_ddp=model_eval,
+                           optimizer=optimizer, lr_scheduler=lr_scheduler, loss_scaler=loss_scaler, epoch=epoch, acc1=max_accuracy)
+        if args.distributed:
+            dist.barrier()
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Training time {}'.format(total_time_str))
+
+
+if __name__ == "__main__":
+    main()

+ 164 - 0
iclab/train.sh

@@ -0,0 +1,164 @@
+# ------------------- Args setting -------------------
+MODEL=$1
+DATASET=$2
+DATASET_ROOT=$3
+WORLD_SIZE=$4
+MASTER_PORT=$5
+RESUME=$6
+
+# ------------------- Training setting -------------------
+## Epoch
+BATCH_SIZE=128
+GRAD_ACCUMULATE=32
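+# Effective batch size = BATCH_SIZE x GRAD_ACCUMULATE = 128 x 32 = 4096,
+# matching the 4096-batch results reported in the README.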
+WP_EPOCH=10
+MAX_EPOCH=100
+EVAL_EPOCH=5
+DROP_PATH=0.1
+## Scheduler
+OPTIMIZER="adamw"
+LRSCHEDULER="cosine"
+BASE_LR=1e-3         # 0.1 for SGD; 0.001 for AdamW
+MIN_LR=1e-6
+BATCH_BASE=1024      # 256 for SGD; 1024 for AdamW
+MOMENTUM=0.9
+WEIGHT_DECAY=0.05    # 0.0001 for SGD; 0.05 for AdamW
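+# NOTE (assumption): the optimizer is presumed to scale the lr linearly with the
+# effective batch size, i.e. lr = BASE_LR * 4096 / BATCH_BASE = 4e-3, which matches
+# the "base lr = 4e-3 (for bs of 4096)" convention noted in the README.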
+
+# ------------------- Dataset config -------------------
+if [[ $DATASET == "mnist" ]]; then
+    IMG_SIZE=28
+    NUM_CLASSES=10
+elif [[ $DATASET == "cifar10" ]]; then
+    IMG_SIZE=32
+    NUM_CLASSES=10
+elif [[ $DATASET == "cifar100" ]]; then
+    IMG_SIZE=32
+    NUM_CLASSES=100
+elif [[ $DATASET == "imagenet_1k" || $DATASET == "imagenet_22k" ]]; then
+    IMG_SIZE=224
+    NUM_CLASSES=1000
+elif [[ $DATASET == "custom" ]]; then
+    IMG_SIZE=224
+    NUM_CLASSES=2
+else
+    echo "Unknown dataset!!"
+    exit 1
+fi
+
+
+# ------------------- Training pipeline -------------------
+if [ $WORLD_SIZE == 1 ]; then
+    python train.py \
+            --cuda \
+            --root ${DATASET_ROOT} \
+            --dataset ${DATASET} \
+            --model ${MODEL} \
+            --resume ${RESUME} \
+            --batch_size ${BATCH_SIZE} \
+            --batch_base ${BATCH_BASE} \
+            --grad_accumulate ${GRAD_ACCUMULATE} \
+            --img_size ${IMG_SIZE} \
+            --drop_path ${DROP_PATH} \
+            --max_epoch ${MAX_EPOCH} \
+            --wp_epoch ${WP_EPOCH} \
+            --eval_epoch ${EVAL_EPOCH} \
+            --optimizer ${OPTIMIZER} \
+            --lr_scheduler ${LRSCHEDULER} \
+            --base_lr ${BASE_LR} \
+            --min_lr ${MIN_LR} \
+            --momentum ${MOMENTUM} \
+            --weight_decay ${WEIGHT_DECAY} \
+            --color_jitter 0.0 \
+            --reprob 0.0 \
+            --mixup 0.0 \
+            --cutmix 0.0
+
+elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
+    python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
+            --cuda \
+            --distributed \
+            --root ${DATASET_ROOT} \
+            --dataset ${DATASET} \
+            --model ${MODEL} \
+            --resume ${RESUME} \
+            --batch_size ${BATCH_SIZE} \
+            --batch_base ${BATCH_BASE} \
+            --grad_accumulate ${GRAD_ACCUMULATE} \
+            --img_size ${IMG_SIZE} \
+            --drop_path ${DROP_PATH} \
+            --max_epoch ${MAX_EPOCH} \
+            --wp_epoch ${WP_EPOCH} \
+            --eval_epoch ${EVAL_EPOCH} \
+            --optimizer ${OPTIMIZER} \
+            --lr_scheduler ${LRSCHEDULER} \
+            --base_lr ${BASE_LR} \
+            --min_lr ${MIN_LR} \
+            --momentum ${MOMENTUM} \
+            --weight_decay ${WEIGHT_DECAY} \
+            --sybn \
+            --color_jitter 0.0 \
+            --reprob 0.0 \
+            --mixup 0.0 \
+            --cutmix 0.0
+else
+    echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
+          multi-card training mode, which is currently unsupported."
+    exit 1
+fi
+
+
+# # ------------------- Training pipeline with strong augmentations -------------------
+# if [ $WORLD_SIZE == 1 ]; then
+#     python train.py \
+#             --cuda \
+#             --root ${DATASET_ROOT} \
+#             --dataset ${DATASET} \
+#             --model ${MODEL} \
+#             --resume ${RESUME} \
+#             --batch_size ${BATCH_SIZE} \
+#             --batch_base ${BATCH_BASE} \
+#             --grad_accumulate ${GRAD_ACCUMULATE} \
+#             --img_size ${IMG_SIZE} \
+#             --drop_path ${DROP_PATH} \
+#             --max_epoch ${MAX_EPOCH} \
+#             --wp_epoch ${WP_EPOCH} \
+#             --eval_epoch ${EVAL_EPOCH} \
+#             --optimizer ${OPTIMIZER} \
+#             --lr_scheduler ${LRSCHEDULER} \
+#             --base_lr ${BASE_LR} \
+#             --min_lr ${MIN_LR} \
+#             --weight_decay ${WEIGHT_DECAY} \
+#             --aa "rand-m9-mstd0.5-inc1" \
+#             --reprob 0.25 \
+#             --mixup 0.8 \
+#             --cutmix 1.0
+# elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
+#     python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
+#             --cuda \
+#             --distributed \
+#             --root ${DATASET_ROOT} \
+#             --dataset ${DATASET} \
+#             --model ${MODEL} \
+#             --resume ${RESUME} \
+#             --batch_size ${BATCH_SIZE} \
+#             --batch_base ${BATCH_BASE} \
+#             --grad_accumulate ${GRAD_ACCUMULATE} \
+#             --img_size ${IMG_SIZE} \
+#             --drop_path ${DROP_PATH} \
+#             --max_epoch ${MAX_EPOCH} \
+#             --wp_epoch ${WP_EPOCH} \
+#             --eval_epoch ${EVAL_EPOCH} \
+#             --optimizer ${OPTIMIZER} \
+#             --lr_scheduler ${LRSCHEDULER} \
+#             --base_lr ${BASE_LR} \
+#             --min_lr ${MIN_LR} \
+#             --weight_decay ${WEIGHT_DECAY} \
+#             --sybn \
+#             --aa "rand-m9-mstd0.5-inc1" \
+#             --reprob 0.25 \
+#             --mixup 0.8 \
+#             --cutmix 1.0
+# else
+#     echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
+#           multi-card training mode, which is currently unsupported."
+#     exit 1
+# fi

+ 0 - 0
iclab/utils/__init__.py


+ 18 - 0
iclab/utils/com_flops_params.py

@@ -0,0 +1,18 @@
+import torch
+from thop import profile
+
+
+def FLOPs_and_Params(model, size, img_dim, device):
+    x = torch.randn(1, img_dim, size, size).to(device)
+    model.eval()
+
+    flops, params = profile(model, inputs=(x, ))
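+    # NOTE: thop counts multiply-accumulates (MACs), so the x2 below reports FLOPs.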
+    print('=================== FLOPs & Params ===================')
+    print('- GFLOPs : ', flops / 1e9 * 2)
+    print('- Params : ', params / 1e6, ' M')
+
+    model.train()
+
+
+if __name__ == "__main__":
+    pass

+ 165 - 0
iclab/utils/distributed_utils.py

@@ -0,0 +1,165 @@
+# from github: https://github.com/ruinmessi/ASFF/blob/master/utils/distributed_util.py
+
+import torch
+import torch.distributed as dist
+import os
+import subprocess
+import pickle
+
+
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to("cuda")
+
+    # obtain Tensor size of each rank
+    local_size = torch.tensor([tensor.numel()], device="cuda")
+    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+    if local_size != max_size:
+        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that all processes
+    have the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.all_reduce(values)
+        if average:
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
+
+
+def get_sha():
+    cwd = os.path.dirname(os.path.abspath(__file__))
+
+    def _run(command):
+        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
+    sha = 'N/A'
+    diff = "clean"
+    branch = 'N/A'
+    try:
+        sha = _run(['git', 'rev-parse', 'HEAD'])
+        subprocess.check_output(['git', 'diff'], cwd=cwd)
+        diff = _run(['git', 'diff-index', 'HEAD'])
+        diff = "has uncommited changes" if diff else "clean"
+        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+    except Exception:
+        pass
+    message = f"sha: {sha}, status: {diff}, branch: {branch}"
+    return message
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    import builtins as __builtin__
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
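+        # note: in the SLURM branch, args.world_size is assumed to be supplied
+        # by the launch arguments rather than read from the environment here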
+    else:
+        print('Not using distributed mode')
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}'.format(
+        args.rank, args.dist_url), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
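
A quick sketch of how these helpers fit together in a training script (the names `total_loss`, `model`, and `dataloader` are illustrative, not part of this commit):

```python
# assumes init_distributed_mode(args) has already been called on every rank
loss_dict = {"loss": total_loss.detach()}
loss_dict_reduced = reduce_dict(loss_dict, average=True)  # identical values on every rank
if is_main_process():
    print({k: v.item() for k, v in loss_dict_reduced.items()})
    save_on_master(model.state_dict(), "checkpoint.pth")  # only rank 0 writes
```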

+ 52 - 0
iclab/utils/ema.py

@@ -0,0 +1,52 @@
+from copy import deepcopy
+import math
+import torch
+import torch.nn as nn
+
+
+# ---------------------------- Model tools ----------------------------
+def is_parallel(model):
+    # Returns True if model is of type DP or DDP
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+## Model EMA
+class ModelEMA(object):
+    def __init__(self, model, ema_decay=0.9999, ema_tau=2000, updates=0):
+        # Create EMA
+        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
+        self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau))  # decay exponential ramp (to help early epochs)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def is_parallel(self, model):
+        # Returns True if model is of type DP or DDP
+        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+    def de_parallel(self, model):
+        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+        return model.module if self.is_parallel(model) else model
+
+    def copy_attr(self, a, b, include=(), exclude=()):
+        # Copy attributes from b to a, options to only include [...] and to exclude [...]
+        for k, v in b.__dict__.items():
+            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
+                continue
+            else:
+                setattr(a, k, v)
+
+    def update(self, model):
+        # Update EMA parameters
+        self.updates += 1
+        d = self.decay(self.updates)
+
+        msd = self.de_parallel(model).state_dict()  # model state_dict
+        for k, v in self.ema.state_dict().items():
+            if v.dtype.is_floating_point:  # true for FP16 and FP32
+                v *= d
+                v += (1 - d) * msd[k].detach()
+        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
+
+    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+        # Update EMA attributes
+        self.copy_attr(self.ema, model, include, exclude)
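
A minimal usage sketch for `ModelEMA` (the training-loop names are illustrative):

```python
ema = ModelEMA(model, ema_decay=0.9999, ema_tau=2000)
for images, targets in dataloader:
    loss = criterion(model(images), targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update(model)  # blend the live weights into the EMA copy
# evaluate or export with ema.ema rather than model
```

Because `decay(x) = ema_decay * (1 - exp(-x / ema_tau))`, the effective decay is small early in training, so the EMA copy tracks the fast-moving weights closely before settling toward `ema_decay`.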

+ 39 - 0
iclab/utils/lr_scheduler.py

@@ -0,0 +1,39 @@
+import torch
+
+
+# Basic Warmup Scheduler
+class LinearWarmUpLrScheduler(object):
+    def __init__(self, base_lr=0.01, wp_iter=500, warmup_factor=0.00066667):
+        self.base_lr = base_lr
+        self.wp_iter = wp_iter
+        self.warmup_factor = warmup_factor
+
+    def set_lr(self, optimizer, cur_lr):
+        for param_group in optimizer.param_groups:
+            init_lr = param_group['initial_lr']
+            ratio = init_lr / self.base_lr
+            param_group['lr'] = cur_lr * ratio
+
+    def __call__(self, cur_iter, optimizer):
+        # linearly ramp the lr from base_lr * warmup_factor up to base_lr
+        assert cur_iter < self.wp_iter
+        alpha = cur_iter / self.wp_iter
+        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
+        tmp_lr = self.base_lr * warmup_factor
+        self.set_lr(optimizer, tmp_lr)
+
+
+def build_lr_scheduler(args, optimizer):
+    if args.lr_scheduler == "step":
+        lr_step = [args.max_epoch // 3, args.max_epoch // 3 * 2]
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_step, gamma=0.1)
+    elif args.lr_scheduler == "cosine":
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch - args.wp_epoch - 1, eta_min=args.min_lr)
+    else:
+        raise NotImplementedError("Unknown lr scheduler: {}".format(args.lr_scheduler))
+
+    print("=================== LR Scheduler information ===================")
+    print("LR Scheduler: ", args.lr_scheduler)
+
+    return scheduler
+

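A sketch of how the per-iteration warmup and the per-epoch scheduler are meant to cooperate (the loop structure and `args` fields are assumed, not shown in this commit):

```python
warmup = LinearWarmUpLrScheduler(base_lr=args.base_lr, wp_iter=500)
scheduler = build_lr_scheduler(args, optimizer)

for epoch in range(args.max_epoch):
    for it, (images, targets) in enumerate(dataloader):
        cur_iter = epoch * len(dataloader) + it
        if cur_iter < warmup.wp_iter:
            warmup(cur_iter, optimizer)  # linear ramp toward base_lr
        ...  # forward / backward / optimizer.step()
    scheduler.step()  # per-epoch decay once warmup is over
```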
+ 291 - 0
iclab/utils/misc.py

@@ -0,0 +1,291 @@
+import time
+import numpy as np
+import random
+import datetime
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+
+from .distributed_utils import get_world_size, is_main_process, is_dist_avail_and_initialized
+
+
+# ---------------------- Common functions ----------------------
+def all_reduce_mean(x):
+    world_size = get_world_size()
+    if world_size > 1:
+        x_reduce = torch.tensor(x).cuda()
+        dist.all_reduce(x_reduce)
+        x_reduce /= world_size
+        return x_reduce.item()
+    else:
+        return x
+
+def print_rank_0(msg, rank=None):
+    if rank is not None and rank <= 0:
+        print(msg)
+    elif is_main_process():
+        print(msg)
+
+def setup_seed(seed=42):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+def is_parallel(model):
+    # Returns True if model is of type DP or DDP
+    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if v is None:
+                continue
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        log_msg = [
+            header,
+            '[{0' + space_fmt + '}/{1}]',
+            'eta: {eta}',
+            '{meters}',
+            'time: {time}',
+            'data: {data}'
+        ]
+        if torch.cuda.is_available():
+            log_msg.append('max mem: {memory:.0f}')
+        log_msg = self.delimiter.join(log_msg)
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(
+            header, total_time_str, total_time / len(iterable)))
+
+
+# ---------------------- Optimize functions ----------------------
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = [p for p in parameters if p.grad is not None]
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.)
+    device = parameters[0].grad.device
+    total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device)
+                                        for p in parameters]),
+                            norm_type)
+
+    return total_norm
+
+class NativeScalerWithGradNormCount:
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+        self._scaler.scale(loss).backward()
+        if update_grad:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+            else:
+                self._scaler.unscale_(optimizer)
+                norm = get_grad_norm_(parameters)
+            self._scaler.step(optimizer)
+            self._scaler.update()
+        else:
+            norm = None
+        return norm
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)
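+
+# Typical AMP training-step usage (a sketch; the surrounding names are illustrative):
+#   loss_scaler = NativeScalerWithGradNormCount()
+#   with torch.cuda.amp.autocast():
+#       loss = criterion(model(images), targets)
+#   grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
+#                           parameters=model.parameters(), update_grad=True)
+#   optimizer.zero_grad()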
+
+
+# ---------------------- Model functions ----------------------
+def load_model(args, model_without_ddp, optimizer, lr_scheduler, loss_scaler):
+    if args.resume and args.resume.lower() != 'none':
+        print("=================== Load checkpoint ===================")
+        if args.resume.startswith('https'):
+            checkpoint = torch.hub.load_state_dict_from_url(
+                args.resume, map_location='cpu', check_hash=True)
+        else:
+            checkpoint = torch.load(args.resume, map_location='cpu')
+        model_without_ddp.load_state_dict(checkpoint['model'])
+        print("Resume checkpoint %s" % args.resume)
+        
+        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
+            print('- Load optimizer from the checkpoint. ')
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            args.start_epoch = checkpoint['epoch'] + 1
+            if 'scaler' in checkpoint:
+                loss_scaler.load_state_dict(checkpoint['scaler'])
+
+        if 'lr_scheduler' in checkpoint:
+            print('- Load lr scheduler from the checkpoint. ')
+            lr_scheduler.load_state_dict(checkpoint.pop("lr_scheduler"))
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, lr_scheduler, loss_scaler, acc1=None):
+    output_dir = Path(args.output_dir)
+    epoch_name = str(epoch)
+    if loss_scaler is not None:
+        if acc1 is not None:
+            checkpoint_paths = [output_dir / ('checkpoint-{}-Acc1-{:.2f}.pth'.format(epoch_name, acc1))]
+        else:
+            checkpoint_paths = [output_dir / ('checkpoint-{}.pth'.format(epoch_name))]
+        for checkpoint_path in checkpoint_paths:
+            to_save = {
+                'model': model_without_ddp.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'lr_scheduler': lr_scheduler.state_dict(),
+                'epoch': epoch,
+                'scaler': loss_scaler.state_dict(),
+                'args': args,
+            }
+            torch.save(to_save, checkpoint_path)
+    else:
+        client_state = {'epoch': epoch}
+        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
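
`MetricLogger.log_every` wraps an iterable and periodically prints smoothed timing and metric statistics; a minimal sketch (the metric names are illustrative):

```python
metric_logger = MetricLogger(delimiter="  ")
for images, targets in metric_logger.log_every(dataloader, print_freq=50, header="Epoch: [0]"):
    loss = criterion(model(images), targets)
    metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
```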

+ 24 - 0
iclab/utils/optimzer.py

@@ -0,0 +1,24 @@
+import torch
+
+
+def build_optimizer(args, model):
+    ## auto-scale the base lr linearly with the effective batch size
+    args.base_lr = args.base_lr / args.batch_base * args.batch_size * args.grad_accumulate
+    ## optimizer
+    if args.optimizer == "adamw":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)
+    elif args.optimizer == "sgd":
+        optimizer = torch.optim.SGD(model.parameters(), lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    else:
+        raise NotImplementedError("Unknown optimizer: {}".format(args.optimizer))
+
+    print("=================== Optimizer information ===================")
+    print("Optimizer: ", args.optimizer)
+    print("- momoentum: ", args.momentum)
+    print("- weight decay: ", args.weight_decay)
+    print('- base lr: ', args.base_lr)
+    print('- min  lr: ', args.min_lr)
+
+    return optimizer
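
To make the lr auto-scaling rule in `build_optimizer` concrete, here is the arithmetic with assumed values (`batch_base` being the reference batch size the base lr was tuned for):

```python
batch_base, base_lr = 256, 1e-3          # assumed reference setting
batch_size, grad_accumulate = 1024, 2    # assumed actual setting
scaled_lr = base_lr / batch_base * batch_size * grad_accumulate
print(scaled_lr)  # 0.008 -> the lr grows linearly with the effective batch size
```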