
all in yolo

yjh0410 1 year ago
Parent
Commit
0ad3b62fc7
100 files changed, with 3408 insertions and 4138 deletions
  1. classify/.gitignore (+0 -0)
  2. classify/README.md (+0 -0)
  3. classify/data/__init__.py (+0 -0)
  4. classify/data/cifar.py (+0 -0)
  5. classify/data/custom.py (+0 -0)
  6. classify/data/mnist.py (+0 -0)
  7. classify/engine.py (+0 -0)
  8. classify/main.py (+0 -0)
  9. classify/models/__init__.py (+0 -3)
  10. classify/models/convnet/build.py (+0 -0)
  11. classify/models/convnet/convnet.py (+0 -0)
  12. classify/models/convnet/modules.py (+0 -0)
  13. classify/models/mlp/build.py (+0 -0)
  14. classify/models/mlp/mlp.py (+0 -0)
  15. classify/models/mlp/modules.py (+0 -0)
  16. classify/models/vit/build.py (+0 -0)
  17. classify/models/vit/modules.py (+0 -0)
  18. classify/models/vit/vit.py (+0 -0)
  19. classify/requirements.txt (+0 -0)
  20. classify/utils/__init__.py (+0 -0)
  21. classify/utils/lr_scheduler.py (+0 -0)
  22. classify/utils/misc.py (+0 -0)
  23. classify/utils/optimzer.py (+0 -0)
  24. image_classification/models/resnet/build.py (+0 -21)
  25. image_classification/models/resnet/modules.py (+0 -164)
  26. image_classification/models/resnet/resnet.py (+0 -110)
  27. masked_image_modeling/.gitignore (+0 -12)
  28. masked_image_modeling/README.md (+0 -73)
  29. masked_image_modeling/data/__init__.py (+0 -38)
  30. masked_image_modeling/data/cifar.py (+0 -65)
  31. masked_image_modeling/data/custom.py (+0 -87)
  32. masked_image_modeling/engine_finetune.py (+0 -107)
  33. masked_image_modeling/engine_pretrain.py (+0 -64)
  34. masked_image_modeling/main_finetune.py (+0 -175)
  35. masked_image_modeling/main_pretrain.py (+0 -211)
  36. masked_image_modeling/models/__init__.py (+0 -9)
  37. masked_image_modeling/models/vit/__init__.py (+0 -3)
  38. masked_image_modeling/models/vit/build.py (+0 -45)
  39. masked_image_modeling/models/vit/modules.py (+0 -186)
  40. masked_image_modeling/models/vit/pos_embed.py (+0 -96)
  41. masked_image_modeling/models/vit/vit.py (+0 -180)
  42. masked_image_modeling/models/vit/vit_cls.py (+0 -28)
  43. masked_image_modeling/models/vit/vit_mae.py (+0 -399)
  44. masked_image_modeling/requirements.txt (+0 -5)
  45. masked_image_modeling/utils/lr_scheduler.py (+0 -37)
  46. masked_image_modeling/utils/misc.py (+0 -231)
  47. masked_image_modeling/utils/optimizer.py (+0 -25)
  48. yolo/.gitignore (+1 -0)
  49. yolo/benchmark.py (+0 -110)
  50. yolo/config/__init__.py (+1 -1)
  51. yolo/config/detr_config.py (+0 -0)
  52. yolo/config/fcos_config.py (+0 -0)
  53. yolo/config/yolof_config.py (+0 -0)
  54. yolo/config/yolov10_config.py (+0 -0)
  55. yolo/config/yolov11_config.py (+0 -0)
  56. yolo/config/yolov4_config.py (+0 -0)
  57. yolo/config/yolov7_config.py (+0 -0)
  58. yolo/config/yolov9_config.py (+0 -2)
  59. yolo/models/__init__.py (+1 -1)
  60. yolo/models/detr/build.py (+0 -0)
  61. yolo/models/detr/detr.py (+0 -0)
  62. yolo/models/detr/detr_backbone.py (+0 -0)
  63. yolo/models/detr/detr_transformer.py (+0 -0)
  64. yolo/models/detr/loss.py (+0 -0)
  65. yolo/models/detr/matcher.py (+0 -0)
  66. yolo/models/detr/modules.py (+148 -0)
  67. yolo/models/fcos/build.py (+0 -0)
  68. yolo/models/fcos/fcos.py (+0 -0)
  69. yolo/models/fcos/fcos_backbone.py (+0 -0)
  70. yolo/models/fcos/fcos_fpn.py (+68 -0)
  71. yolo/models/fcos/fcos_head.py (+186 -0)
  72. yolo/models/fcos/loss.py (+290 -0)
  73. yolo/models/fcos/matcher.py (+378 -0)
  74. yolo/models/fcos/modules.py (+148 -0)
  75. yolo/models/fcos/resnet.py (+187 -0)
  76. yolo/models/gelan/build.py (+0 -24)
  77. yolo/models/gelan/gelan.py (+0 -165)
  78. yolo/models/gelan/gelan_backbone.py (+0 -198)
  79. yolo/models/gelan/gelan_basic.py (+0 -312)
  80. yolo/models/gelan/gelan_head.py (+0 -176)
  81. yolo/models/gelan/gelan_neck.py (+0 -76)
  82. yolo/models/gelan/gelan_pafpn.py (+0 -158)
  83. yolo/models/gelan/gelan_pred.py (+0 -155)
  84. yolo/models/gelan/loss.py (+0 -187)
  85. yolo/models/gelan/matcher.py (+0 -199)
  86. yolo/models/yolof/build.py (+0 -0)
  87. yolo/models/yolof/loss.py (+144 -0)
  88. yolo/models/yolof/matcher.py (+103 -0)
  89. yolo/models/yolof/modules.py (+148 -0)
  90. yolo/models/yolof/resnet.py (+187 -0)
  91. yolo/models/yolof/yolof.py (+0 -0)
  92. yolo/models/yolof/yolof_backbone.py (+0 -0)
  93. yolo/models/yolof/yolof_decoder.py (+185 -0)
  94. yolo/models/yolof/yolof_encoder.py (+72 -0)
  95. yolo/models/yolov10/README.md (+56 -0)
  96. yolo/models/yolov10/build.py (+66 -0)
  97. yolo/models/yolov10/loss.py (+212 -0)
  98. yolo/models/yolov10/matcher.py (+187 -0)
  99. yolo/models/yolov10/modules.py (+338 -0)
  100. yolo/models/yolov10/yolov10.py (+302 -0)

+ 0 - 0
image_classification/.gitignore → classify/.gitignore


+ 0 - 0
image_classification/README.md → classify/README.md


+ 0 - 0
image_classification/data/__init__.py → classify/data/__init__.py


+ 0 - 0
image_classification/data/cifar.py → classify/data/cifar.py


+ 0 - 0
image_classification/data/custom.py → classify/data/custom.py


+ 0 - 0
image_classification/data/mnist.py → classify/data/mnist.py


+ 0 - 0
image_classification/engine.py → classify/engine.py


+ 0 - 0
image_classification/main.py → classify/main.py


+ 0 - 3
image_classification/models/__init__.py → classify/models/__init__.py

@@ -1,6 +1,5 @@
 from .mlp.build     import build_mlp
 from .convnet.build import build_convnet
-from .resnet.build  import build_resnet
 from .vit.build     import build_vit


@@ -10,8 +9,6 @@ def build_model(args):
        model = build_mlp(args)
    elif 'convnet' in args.model:
        model = build_convnet(args)
-    elif 'resnet' in args.model:
-        model = build_resnet(args)
    elif 'vit' in args.model:
        model = build_vit(args)
    else:

+ 0 - 0
image_classification/models/convnet/build.py → classify/models/convnet/build.py


+ 0 - 0
image_classification/models/convnet/convnet.py → classify/models/convnet/convnet.py


+ 0 - 0
image_classification/models/convnet/modules.py → classify/models/convnet/modules.py


+ 0 - 0
image_classification/models/mlp/build.py → classify/models/mlp/build.py


+ 0 - 0
image_classification/models/mlp/mlp.py → classify/models/mlp/mlp.py


+ 0 - 0
image_classification/models/mlp/modules.py → classify/models/mlp/modules.py


+ 0 - 0
image_classification/models/vit/build.py → classify/models/vit/build.py


+ 0 - 0
image_classification/models/vit/modules.py → classify/models/vit/modules.py


+ 0 - 0
image_classification/models/vit/vit.py → classify/models/vit/vit.py


+ 0 - 0
image_classification/requirements.txt → classify/requirements.txt


+ 0 - 0
image_classification/utils/__init__.py → classify/utils/__init__.py


+ 0 - 0
image_classification/utils/lr_scheduler.py → classify/utils/lr_scheduler.py


+ 0 - 0
image_classification/utils/misc.py → classify/utils/misc.py


+ 0 - 0
image_classification/utils/optimzer.py → classify/utils/optimzer.py


+ 0 - 21
image_classification/models/resnet/build.py

@@ -1,21 +0,0 @@
-from .resnet import ResNet
-from .modules import PlainResBlock, BottleneckResBlock
-
-
-def build_resnet(args):
-    if args.model == 'resnet18':
-        model = ResNet(in_dim=args.img_dim,
-                       block=PlainResBlock,
-                       expansion=1.0,
-                       num_blocks=[2, 2, 2, 2],
-                       )
-    elif args.model == 'resnet50':
-        model = ResNet(in_dim=args.img_dim,
-                       block=BottleneckResBlock,
-                       expansion=4.0,
-                       num_blocks=[3, 4, 6, 3],
-                       )
-    else:
-        raise NotImplementedError("Unknown resnet: {}".format(args.model))
-    
-    return model
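
For context, here is a minimal usage sketch of the `build_resnet` factory removed above. The `SimpleNamespace` stand-in for the parsed arguments is hypothetical; the real call site was the `'resnet'` branch of `build_model(args)` in `models/__init__.py`, which this commit also deletes.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the argparse namespace passed to build_model();
# build_resnet() only reads args.model and args.img_dim.
args = SimpleNamespace(model='resnet18', img_dim=3)

model = build_resnet(args)  # ResNet-18: PlainResBlock, expansion=1.0, num_blocks=[2, 2, 2, 2]
```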

+ 0 - 164
image_classification/models/resnet/modules.py

@@ -1,164 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-def get_activation(act_type=None):
-    if   act_type == 'sigmoid':
-        return nn.Sigmoid()
-    elif act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-    elif act_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-        
-def get_norm(norm_type, dim):
-    if   norm_type == 'bn':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'ln':
-        return LayerNorm2d(dim)
-    elif norm_type == 'gn':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
-class LayerNorm2d(nn.Module):
-    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(num_channels))
-        self.bias = nn.Parameter(torch.zeros(num_channels))
-        self.eps = eps
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        u = x.mean(1, keepdim=True)
-        s = (x - u).pow(2).mean(1, keepdim=True)
-        x = (x - u) / torch.sqrt(s + self.eps)
-        x = self.weight[:, None, None] * x + self.bias[:, None, None]
-        
-        return x
-    
-class ConvModule(nn.Module):
-    def __init__(self,
-                 in_dim      :int,
-                 out_dim     :int,
-                 kernel_size :int  = 1,
-                 padding     :int  = 0,
-                 stride      :int  = 1,
-                 act_type    :str  = "relu",
-                 norm_type   :str  = "bn",
-                 depthwise   :bool = False) -> None:
-        super().__init__()
-        use_bias = False if norm_type is not None else True
-        self.depthwise = depthwise
-        if not depthwise:
-            self.conv = nn.Conv2d(in_channels=in_dim, out_channels=out_dim,
-                                kernel_size=kernel_size, padding=padding, stride=stride,
-                                bias=use_bias)
-            self.norm  = get_norm(norm_type, out_dim)
-        else:
-            self.conv1 = nn.Conv2d(in_channels=in_dim, out_channels=in_dim,
-                                   kernel_size=kernel_size, padding=padding, stride=stride, groups=in_dim,
-                                   bias=use_bias)
-            self.norm1 = get_norm(norm_type, in_dim)
-            self.conv2 = nn.Conv2d(in_channels=in_dim, out_channels=out_dim,
-                                   kernel_size=1, padding=0, stride=1,
-                                   bias=use_bias)
-            self.norm2 = get_norm(norm_type, out_dim)
-        self.act   = get_activation(act_type)
-
-    def forward(self, x):
-        if self.depthwise:
-            x = self.norm1(self.conv1(x))
-            x = self.act(self.norm2(self.conv2(x)))
-        else:
-            x = self.act(self.norm(self.conv(x)))
-
-        return x
-
-
-# -------------- ResNet's modules --------------
-class PlainResBlock(nn.Module):
-    def __init__(self, in_dim, inter_dim, out_dim, stride=1):
-        super().__init__()
-        # -------- Basic parameters --------
-        self.in_dim = in_dim
-        self.out_dim = out_dim
-        self.inter_dim = inter_dim
-        self.stride = stride
-        self.downsample = stride > 1 or in_dim != out_dim
-
-        # -------- Model parameters --------
-        self.conv_layer_1 = ConvModule(in_dim, inter_dim,
-                                       kernel_size=3, padding=1, stride=stride,
-                                       act_type='relu', norm_type='bn', depthwise=False)
-        self.conv_layer_2 = ConvModule(inter_dim, out_dim,
-                                       kernel_size=3, padding=1, stride=1,
-                                       act_type=None, norm_type='bn', depthwise=False)
-        self.out_act = nn.ReLU(inplace=True)
-
-        if self.downsample:
-            self.res_layer = ConvModule(in_dim, out_dim,
-                                       kernel_size=1, padding=0, stride=stride,
-                                       act_type=None, norm_type='bn', depthwise=False)
-        else:
-            self.res_layer = nn.Identity()
-
-    def forward(self, x):
-        out = self.conv_layer_1(x)
-        out = self.conv_layer_2(out)
-
-        x = self.res_layer(x)
-
-        out = x + out
-        out = self.out_act(out)
-
-        return out
-
-class BottleneckResBlock(nn.Module):
-    def __init__(self, in_dim, inter_dim, out_dim, stride=1):
-        super().__init__()
-        # -------- Basic parameters --------
-        self.in_dim = in_dim
-        self.out_dim = out_dim
-        self.stride = stride
-        self.downsample = stride > 1 or in_dim != out_dim
-
-        # -------- Model parameters --------
-        self.conv_layer_1 = ConvModule(in_dim, inter_dim,
-                                       kernel_size=1, padding=0, stride=1,
-                                       act_type='relu', norm_type='bn', depthwise=False)
-        self.conv_layer_2 = ConvModule(inter_dim, inter_dim,
-                                       kernel_size=3, padding=1, stride=stride,
-                                       act_type='relu', norm_type='bn', depthwise=False)
-        self.conv_layer_3 = ConvModule(inter_dim, out_dim,
-                                       kernel_size=1, padding=0, stride=1,
-                                       act_type=None, norm_type='bn', depthwise=False)
-        self.out_act = nn.ReLU(inplace=True)
-
-        if self.downsample:
-            self.res_layer = ConvModule(in_dim, out_dim,
-                                       kernel_size=1, padding=0, stride=stride,
-                                       act_type=None, norm_type='bn', depthwise=False)
-        else:
-            self.res_layer = nn.Identity()
-
-    def forward(self, x):
-        out = self.conv_layer_1(x)
-        out = self.conv_layer_2(out)
-        out = self.conv_layer_3(out)
-
-        x = self.res_layer(x)
-
-        out = x + out
-        out = self.out_act(out)
-
-        return out
-

+ 0 - 110
image_classification/models/resnet/resnet.py

@@ -1,110 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .modules import ConvModule, PlainResBlock, BottleneckResBlock
-except:
-    from  modules import ConvModule, PlainResBlock, BottleneckResBlock
-
-
-class ResNet(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 block,
-                 expansion = 1.0,
-                 num_blocks = [2, 2, 2, 2],
-                 num_classes = 1000,
-                 ) -> None:
-        super().__init__()
-        # ----------- Basic parameters -----------
-        self.expansion = expansion
-        self.num_blocks = num_blocks
-        self.feat_dims  = [64,                      # C2 level
-                           round(64 * expansion),   # C2 level
-                           round(128 * expansion),  # C3 level
-                           round(256 * expansion),  # C4 level
-                           round(512 * expansion),  # C5 level
-                           ]
-        # ----------- Model parameters -----------
-        ## Backbone
-        self.layer_1 = nn.Sequential(
-            ConvModule(in_dim, self.feat_dims[0],
-                       kernel_size=7, padding=3, stride=2,
-                       act_type='relu', norm_type='bn', depthwise=False),
-            nn.MaxPool2d(kernel_size=(3, 3), padding=(1, 1), stride=(2, 2))
-        )
-        self.layer_2 = self.make_layer(block, self.feat_dims[0], self.feat_dims[1], depth=num_blocks[0], downsample=False)
-        self.layer_3 = self.make_layer(block, self.feat_dims[1], self.feat_dims[2], depth=num_blocks[1], downsample=True)
-        self.layer_4 = self.make_layer(block, self.feat_dims[2], self.feat_dims[3], depth=num_blocks[2], downsample=True)
-        self.layer_5 = self.make_layer(block, self.feat_dims[3], self.feat_dims[4], depth=num_blocks[3], downsample=True)
-
-        ## Classifier
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc      = nn.Linear(self.feat_dims[4] , num_classes)
-        
-    def make_layer(self, block, in_dim, out_dim, depth=1, downsample=False):
-        stage_blocks = []
-        for i in range(depth):
-            if i == 0:
-                stride = 2 if downsample else 1
-                inter_dim = round(out_dim / self.expansion)
-                stage_blocks.append(block(in_dim, inter_dim, out_dim, stride))
-            else:
-                stride = 1
-                inter_dim = round(out_dim / self.expansion)
-                stage_blocks.append(block(out_dim, inter_dim, out_dim, stride))
-        
-        layers = nn.Sequential(*stage_blocks)
-
-        return layers
-    
-    def forward(self, x):
-        x = self.layer_1(x)
-        x = self.layer_2(x)
-        x = self.layer_3(x)
-        x = self.layer_4(x)
-        x = self.layer_5(x)
-
-        x = self.avgpool(x)
-        x = x.flatten(1)
-        x = self.fc(x)
-
-        return x
-
-
-def build_resnet(model_name='resnet18', img_dim=3):
-    if model_name == 'resnet18':
-        model = ResNet(in_dim=img_dim,
-                       block=PlainResBlock,
-                       expansion=1.0,
-                       num_blocks=[2, 2, 2, 2],
-                       )
-    elif model_name == 'resnet50':
-        model = ResNet(in_dim=img_dim,
-                       block=BottleneckResBlock,
-                       expansion=4.0,
-                       num_blocks=[3, 4, 6, 3],
-                       )
-    else:
-        raise NotImplementedError("Unknown resnet: {}".format(model_name))
-    
-    return model
-
-
-if __name__=='__main__':
-    import time
-
-    # Build the ResNet model
-    model = build_resnet(model_name='resnet18')
-
-    # Print the model structure
-    print(model)
-
-    # Randomly generate input data
-    x = torch.randn(1, 3, 224, 224)
-
-    # Run a forward pass through the model
-    t0 = time.time()
-    output = model(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)

+ 0 - 12
masked_image_modeling/.gitignore

@@ -1,12 +0,0 @@
-*.pt
-*.pth
-*.pkl
-*.onnx
-*.pyc
-*.zip
-weights
-__pycache__
-data/cifar/
-data/cifar_data/
-data/mnist_data/
-vis_results/

+ 0 - 73
masked_image_modeling/README.md

@@ -1,73 +0,0 @@
-# Masked AutoEncoder
-
-## 1. Pretrain
-We provide the bash script `main_pretrain.sh` for pretraining. You can adjust the hyperparameters in the script to suit your needs.
-
-```Shell
-cd Vision-Pretraining-Tutorial/masked_image_modeling/
-python main_pretrain.py --cuda \
-                        --dataset cifar10 \
-                        --model vit_t \
-                        --mask_ratio 0.75 \
-                        --batch_size 128 \
-                        --optimizer adamw \
-                        --weight_decay 0.05 \
-                        --lr_scheduler cosine \
-                        --base_lr 0.00015 \
-                        --min_lr 0.0 \
-                        --max_epoch 400 \
-                        --eval_epoch 20
-```
-
-## 2. Finetune
-We provide the bash script `main_finetune.sh` for finetuning. You can adjust the hyperparameters in the script to suit your needs.
-
-```Shell
-cd Vision-Pretraining-Tutorial/masked_image_modeling/
-python main_finetune.py --cuda \
-                        --dataset cifar10 \
-                        --model vit_t \
-                        --batch_size 256 \
-                        --optimizer adamw \
-                        --weight_decay 0.05 \
-                        --base_lr 0.0005 \
-                        --min_lr 0.000001 \
-                        --max_epoch 100 \
-                        --wp_epoch 5 \
-                        --eval_epoch 5 \
-                        --pretrained path/to/vit_t.pth
-```
-## 3. Evaluate 
-- Evaluate the `top1 & top5` accuracy of `ViT-Tiny` on the CIFAR10 dataset:
-```Shell
-python main_finetune.py --cuda \
-                        --dataset cifar10 \
-                        -m vit_t \
-                        --batch_size 256 \
-                        --eval \
-                        --resume path/to/vit_t_cifar10.pth
-```
-
-
-## 4. Visualize Image Reconstruction
-- Visualize the reconstructions of `ViT-Tiny` pretrained with the MAE framework on the CIFAR10 dataset:
-```Shell
-python main_pretrain.py --cuda \
-                        --dataset cifar10 \
-                        -m vit_t \
-                        --resume path/to/mae_vit_t_cifar10.pth \
-                        --eval \
-                        --batch_size 1
-```
-
-
-## 5. Experiments
-- On CIFAR10
-
-| Method |  Model  | Epoch | Top 1    | Weight |  MAE weight  |
-|  :---: |  :---:  | :---: | :---:    | :---:  |    :---:     |
-|  MAE   |  ViT-T  | 100   |   91.2   | [ckpt](https://github.com/yjh0410/MAE/releases/download/checkpoints/ViT-T_Cifar10.pth) | [ckpt](https://github.com/yjh0410/MAE/releases/download/checkpoints/MAE_ViT-T_Cifar10.pth) |
-
-
-## 6. Acknowledgment
-Thank you to **Kaiming He** for his inspiring work on [MAE](http://openaccess.thecvf.com/content/CVPR2022/papers/He_Masked_Autoencoders_Are_Scalable_Vision_Learners_CVPR_2022_paper.pdf). His research effectively elucidates the semantic distinctions between vision and language, offering valuable insights for subsequent vision-related studies. I would also like to express my gratitude for the official source code of [MAE](https://github.com/facebookresearch/mae). Additionally, I appreciate the efforts of [**IcarusWizard**](https://github.com/IcarusWizard) for reproducing the [MAE](https://github.com/IcarusWizard/MAE) implementation.

+ 0 - 38
masked_image_modeling/data/__init__.py

@@ -1,38 +0,0 @@
-import torch
-
-from .cifar import CifarDataset
-from .custom import CustomDataset
-
-
-def build_dataset(args, is_train=False):
-    # ----------------- CIFAR dataset -----------------
-    if args.dataset == 'cifar10':
-        args.num_classes = 10
-        args.img_dim = 3
-        args.img_size = 32
-        args.patch_size = 4
-        return CifarDataset(is_train)
-        
-    # ----------------- Custom dataset -----------------
-    elif args.dataset == 'custom':
-        assert args.num_classes is not None and isinstance(args.num_classes, int)
-        args.img_size = 224
-        args.patch_size = 16
-        return CustomDataset(args, is_train)
-    
-    else:
-        print("Unknown dataset: {}".format(args.dataset))
-    
-
-def build_dataloader(args, dataset, is_train=False):
-    if is_train:
-        sampler = torch.utils.data.RandomSampler(dataset)
-        batch_sampler_train = torch.utils.data.BatchSampler(
-            sampler, args.batch_size, drop_last=True if is_train else False)
-        dataloader = torch.utils.data.DataLoader(
-            dataset, batch_sampler=batch_sampler_train, num_workers=args.num_workers, pin_memory=True)
-    else:
-        dataloader = torch.utils.data.DataLoader(
-            dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
-
-    return dataloader

+ 0 - 65
masked_image_modeling/data/cifar.py

@@ -1,65 +0,0 @@
-import os
-import numpy as np
-import torch.utils.data as data
-import torchvision.transforms as T
-from torchvision.datasets import CIFAR10
-
-
-class CifarDataset(data.Dataset):
-    def __init__(self, is_train=False):
-        super().__init__()
-        # ----------------- basic parameters -----------------
-        self.pixel_mean = [0.5, 0.5, 0.5]
-        self.pixel_std =  [0.5, 0.5, 0.5]
-        self.is_train  = is_train
-        self.image_set = 'train' if is_train else 'val'
-        # ----------------- dataset & transforms -----------------
-        self.transform = self.build_transform()
-        path = os.path.dirname(os.path.abspath(__file__))
-        if is_train:
-            self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=True, download=True, transform=self.transform)
-        else:
-            self.dataset = CIFAR10(os.path.join(path, 'cifar_data/'), train=False, download=True, transform=self.transform)
-
-    def __len__(self):
-        return len(self.dataset)
-    
-    def __getitem__(self, index):
-        image, target = self.dataset[index]
-            
-        return image, target
-    
-    def pull_image(self, index):
-        # load data
-        image, target = self.dataset[index]
-
-        # denormalize image
-        image = image.permute(1, 2, 0).numpy()
-        image = (image * self.pixel_std + self.pixel_mean) * 255.
-        image = image.astype(np.uint8)
-        image = image.copy()
-
-        return image, target
-
-    def build_transform(self):
-        if self.is_train:
-            transforms = T.Compose([T.ToTensor(), T.Normalize(0.5, 0.5)])
-        else:
-            transforms = T.Compose([T.ToTensor(), T.Normalize(0.5, 0.5)])
-
-        return transforms
-
-if __name__ == "__main__":
-    import cv2
-    
-    # dataset
-    dataset = CifarDataset(is_train=True)  
-    print('Dataset size: ', len(dataset))
-
-    for i in range(len(dataset)):
-        image, target = dataset.pull_image(i)
-        # to BGR
-        image = image[..., (2, 1, 0)]
-
-        cv2.imshow('image', image)
-        cv2.waitKey(0)

+ 0 - 87
masked_image_modeling/data/custom.py

@@ -1,87 +0,0 @@
-import os
-import PIL
-import numpy as np
-import torch.utils.data as data
-import torchvision.transforms as T
-from torchvision.datasets import ImageFolder
-
-
-class CustomDataset(data.Dataset):
-    def __init__(self, args, is_train=False):
-        super().__init__()
-        # ----------------- basic parameters -----------------
-        self.args = args
-        self.is_train   = is_train
-        self.pixel_mean = [0.485, 0.456, 0.406]
-        self.pixel_std  = [0.229, 0.224, 0.225]
-        print("Pixel mean: {}".format(self.pixel_mean))
-        print("Pixel std:  {}".format(self.pixel_std))
-        self.image_set = 'train' if is_train else 'val'
-        self.data_path = os.path.join(args.root, self.image_set)
-        # ----------------- dataset & transforms -----------------
-        self.transform = self.build_transform()
-        self.dataset = ImageFolder(root=self.data_path, transform=self.transform)
-
-    def __len__(self):
-        return len(self.dataset)
-    
-    def __getitem__(self, index):
-        image, target = self.dataset[index]
-
-        return image, target
-    
-    def pull_image(self, index):
-        # load data
-        image, target = self.dataset[index]
-
-        # denormalize image
-        image = image.permute(1, 2, 0).numpy()
-        image = (image * self.pixel_std + self.pixel_mean) * 255.
-        image = image.astype(np.uint8)
-        image = image.copy()
-
-        return image, target
-
-    def build_transform(self):
-        if self.is_train:
-            transforms = T.Compose([
-                            T.RandomResizedCrop(224),
-                            T.RandomHorizontalFlip(0.5),
-                            T.ToTensor(),
-                            T.Normalize(self.pixel_mean,
-                                        self.pixel_std)])
-        else:
-            transforms = T.Compose([
-                T.Resize(224, interpolation=PIL.Image.BICUBIC),
-                T.CenterCrop(224),
-                T.ToTensor(),
-                T.Normalize(self.pixel_mean, self.pixel_std),
-            ])
-
-        return transforms
-
-
-if __name__ == "__main__":
-    import cv2
-    import argparse
-    
-    parser = argparse.ArgumentParser(description='Custom-Dataset')
-
-    # opt
-    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/classification/dataset/Animals/',
-                        help='data root')
-    parser.add_argument('--img_size', default=224, type=int,
-                        help='input image size.')
-    args = parser.parse_args()
-  
-    # Dataset
-    dataset = CustomDataset(args, is_train=True)  
-    print('Dataset size: ', len(dataset))
-
-    for i in range(len(dataset)):
-        image, target = dataset.pull_image(i)
-        # to BGR
-        image = image[..., (2, 1, 0)]
-
-        cv2.imshow('image', image)
-        cv2.waitKey(0)

+ 0 - 107
masked_image_modeling/engine_finetune.py

@@ -1,107 +0,0 @@
-import sys
-import math
-import torch
-
-from utils.misc import MetricLogger, SmoothedValue, accuracy
-
-
-def train_one_epoch(args,
-                    device,
-                    model,
-                    data_loader,
-                    optimizer,
-                    epoch,
-                    lr_scheduler_warmup,
-                    criterion,
-                    ):
-    model.train(True)
-    metric_logger = MetricLogger(delimiter="  ")
-    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
-    header = 'Epoch: [{}]'.format(epoch)
-    print_freq = 20
-    epoch_size = len(data_loader)
-
-    optimizer.zero_grad()
-
-    # train one epoch
-    for iter_i, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
-        ni = iter_i + epoch * epoch_size
-        nw = args.wp_epoch * epoch_size
-
-        # Warmup
-        if nw > 0 and ni < nw:
-            lr_scheduler_warmup(ni, optimizer)
-        elif ni == nw:
-            print("Warmup stage is over.")
-            lr_scheduler_warmup.set_lr(optimizer, args.base_lr)
-
-        # To device
-        images = images.to(device, non_blocking=True)
-        targets = targets.to(device, non_blocking=True)
-
-        # Inference
-        output = model(images)
-
-        # Compute loss
-        loss = criterion(output, targets)
-
-        # Check loss
-        loss_value = loss.item()
-        if not math.isfinite(loss_value):
-            print("Loss is {}, stopping training".format(loss_value))
-            sys.exit(1)
-
-        # Backward
-        loss.backward()
-
-        # Optimize
-        optimizer.step()
-        optimizer.zero_grad()
-
-        # Logs
-        lr = optimizer.param_groups[0]["lr"]
-        metric_logger.update(loss=loss_value)
-        metric_logger.update(lr=lr)
-
-    # gather the stats from all processes
-    print("Averaged stats: {}".format(metric_logger))
-
-    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
-
-
-@torch.no_grad()
-def evaluate(data_loader, model, device):
-    criterion = torch.nn.CrossEntropyLoss()
-
-    metric_logger = MetricLogger(delimiter="  ")
-    header = 'Test:'
-
-    # switch to evaluation mode
-    model.eval()
-
-    for batch in metric_logger.log_every(data_loader, 10, header):
-        images = batch[0]
-        target = batch[1]
-        images = images.to(device, non_blocking=True)
-        target = target.to(device, non_blocking=True)
-
-        # Inference
-        output = model(images)
-
-        # Compute loss
-        loss = criterion(output, target)
-
-        # Compute accuracy
-        acc1, acc5 = accuracy(output, target, topk=(1, 5))
-
-        batch_size = images.shape[0]
-        metric_logger.update(loss=loss.item())
-        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
-        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
-
-    # gather the stats from all processes
-    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
-          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss),
-          )
-
-    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

+ 0 - 64
masked_image_modeling/engine_pretrain.py

@@ -1,64 +0,0 @@
-import sys
-import math
-
-from utils.misc import MetricLogger, SmoothedValue
-
-
-def train_one_epoch(args,
-                    device,
-                    model,
-                    data_loader,
-                    optimizer,
-                    epoch,
-                    lr_scheduler_warmup,
-                    ):
-    model.train(True)
-    metric_logger = MetricLogger(delimiter="  ")
-    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
-    header = 'Epoch: [{}]'.format(epoch)
-    print_freq = 20
-    epoch_size = len(data_loader)
-
-    # Train one epoch
-    for iter_i, (images, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
-        ni = iter_i + epoch * epoch_size
-        nw = args.wp_epoch * epoch_size
-        
-        # Warmup
-        if nw > 0 and ni < nw:
-            lr_scheduler_warmup(ni, optimizer)
-        elif ni == nw:
-            print("Warmup stage is over.")
-            lr_scheduler_warmup.set_lr(optimizer, args.base_lr)
-
-        # To device
-        images = images.to(device, non_blocking=True)
-
-        # Inference
-        output = model(images)
-
-        # Compute loss
-        loss = output["loss"]
-
-        # Check loss
-        loss_value = loss.item()
-        if not math.isfinite(loss_value):
-            print("Loss is {}, stopping training".format(loss_value))
-            sys.exit(1)
-
-        # Backward
-        loss.backward()
-
-        # Optimize
-        optimizer.step()
-        optimizer.zero_grad()
-
-        # Logs
-        lr = optimizer.param_groups[0]["lr"]
-        metric_logger.update(loss=loss_value)
-        metric_logger.update(lr=lr)
-
-    # gather the stats from all processes
-    print("Averaged stats: {}".format(metric_logger))
-
-    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
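
Both training loops above warm the learning rate up per iteration via `lr_scheduler_warmup(ni, optimizer)` and then pin it with `set_lr(optimizer, args.base_lr)` once warmup ends. The deleted `utils/lr_scheduler.py` is not reproduced in this commit view, so the sketch below only matches those call sites; the linear ramp is an assumption, not the repository's exact implementation.

```python
class LinearWarmUpLrScheduler:
    """Minimal sketch matching the calls in engine_finetune.py / engine_pretrain.py.
    Assumption: the lr ramps linearly from 0 to base_lr over wp_iter iterations."""
    def __init__(self, base_lr, wp_iter=500):
        self.base_lr = base_lr
        self.wp_iter = wp_iter

    def set_lr(self, optimizer, lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def __call__(self, ni, optimizer):
        # ni is the global iteration index; clamp the ramp at the warmup horizon
        ratio = min(ni / self.wp_iter, 1.0)
        self.set_lr(optimizer, self.base_lr * ratio)
```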

+ 0 - 175
masked_image_modeling/main_finetune.py

@@ -1,175 +0,0 @@
-import os
-import time
-import argparse
-import datetime
-
-# ---------------- Torch components ----------------
-import torch
-import torch.backends.cudnn as cudnn
-
-# ---------------- Dataset components ----------------
-from data import build_dataset, build_dataloader
-
-# ---------------- Model components ----------------
-from models import build_model
-
-# ---------------- Utils components ----------------
-from utils.misc import setup_seed, load_model, save_model
-from utils.optimizer import build_optimizer
-from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
-
-# ---------------- Training engine ----------------
-from engine_finetune import train_one_epoch, evaluate
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    # Input
-    parser.add_argument('--img_dim', type=int, default=3,
-                        help='3 for RGB; 1 for Gray.')    
-    parser.add_argument('--patch_size', type=int, default=16,
-                        help='patch_size.')
-    # Basic
-    parser.add_argument('--seed', type=int, default=42,
-                        help='random seed.')
-    parser.add_argument('--cuda', action='store_true', default=False,
-                        help='use cuda')
-    parser.add_argument('--batch_size', type=int, default=256,
-                        help='batch size on all GPUs')
-    parser.add_argument('--num_workers', type=int, default=4,
-                        help='number of workers')
-    parser.add_argument('--path_to_save', type=str, default='weights/',
-                        help='path to save trained model.')
-    parser.add_argument('--eval', action='store_true', default=False,
-                        help='evaluate model.')
-    # Epoch
-    parser.add_argument('--wp_epoch', type=int, default=5, 
-                        help='warmup epoch')
-    parser.add_argument('--start_epoch', type=int, default=0, 
-                        help='start epoch')
-    parser.add_argument('--max_epoch', type=int, default=50, 
-                        help='max epoch')
-    parser.add_argument('--eval_epoch', type=int, default=5, 
-                        help='evaluation interval (in epochs)')
-    # Dataset
-    parser.add_argument('--dataset', type=str, default='cifar10',
-                        help='dataset name')
-    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
-                        help='path to dataset folder')
-    parser.add_argument('--num_classes', type=int, default=None, 
-                        help='number of classes.')
-    # Model
-    parser.add_argument('-m', '--model', type=str, default='vit_t',
-                        help='model name')
-    parser.add_argument('--pretrained', default=None, type=str,
-                        help='load pretrained weight.')
-    parser.add_argument('--resume', default=None, type=str,
-                        help='keep training')
-    parser.add_argument('--drop_path', type=float, default=0.1,
-                        help='drop_path')
-    # Optimizer
-    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
-                        help='sgd, adam')
-    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
-                        help='weight decay')
-    parser.add_argument('--base_lr', type=float, default=0.001,
-                        help='learning rate for training model')
-    parser.add_argument('--min_lr', type=float, default=0,
-                        help='the final lr')
-    # Lr scheduler
-    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='cosine',
-                        help='step, cosine')
-
-    return parser.parse_args()
-
-    
-def main():
-    args = parse_args()
-    # set random seed
-    setup_seed(args.seed)
-
-    # Path to save model
-    path_to_save = os.path.join(args.path_to_save, args.dataset, "finetune", args.model)
-    os.makedirs(path_to_save, exist_ok=True)
-    args.output_dir = path_to_save
-
-    # ------------------------- Build CUDA -------------------------
-    if args.cuda:
-        if torch.cuda.is_available():
-            cudnn.benchmark = True
-            device = torch.device("cuda")
-        else:
-            print('There is no available GPU.')
-            args.cuda = False
-            device = torch.device("cpu")
-    else:
-        device = torch.device("cpu")
-
-    # ------------------------- Build Dataset -------------------------
-    train_dataset = build_dataset(args, is_train=True)
-    val_dataset   = build_dataset(args, is_train=False)
-
-    # ------------------------- Build Dataloader -------------------------
-    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
-    val_dataloader   = build_dataloader(args, val_dataset,   is_train=False)
-
-    print('=================== Dataset Information ===================')
-    print('Train dataset size : ', len(train_dataset))
-    print('Val dataset size   : ', len(val_dataset))
-
-    # ------------------------- Build Model -------------------------
-    model = build_model(args, model_type='cls')
-    model.train().to(device)
-    print(model)
-
-    # ------------------------- Build Optimizer -------------------------
-    optimizer = build_optimizer(args, model)
-
-    # ------------------------- Build Lr Scheduler -------------------------
-    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
-    lr_scheduler = build_lr_scheduler(args, optimizer)
-
-
-    # ------------------------- Build Criterion -------------------------
-    criterion = torch.nn.CrossEntropyLoss()
-    load_model(args, model, optimizer, lr_scheduler)
-
-    # ------------------------- Eval before Train Pipeline -------------------------
-    if args.eval:
-        print('evaluating ...')
-        test_stats = evaluate(val_dataloader, model, device)
-        print('Eval Results: [loss: %.2f][acc1: %.2f][acc5 : %.2f]' %
-                (test_stats['loss'], test_stats['acc1'], test_stats['acc5']), flush=True)
-        return
-
-    # ------------------------- Training Pipeline -------------------------
-    start_time = time.time()
-    max_accuracy = -1.0
-    print("=============== Start training for {} epochs ===============".format(args.max_epoch))
-    for epoch in range(args.start_epoch, args.max_epoch):
-        # Train one epoch
-        train_one_epoch(args, device, model, train_dataloader, optimizer,
-                        epoch, lr_scheduler_warmup, criterion)
-
-        # LR scheduler
-        if (epoch + 1) > args.wp_epoch:
-            lr_scheduler.step()
-
-        # Evaluate
-        if (epoch % args.eval_epoch) == 0 or (epoch + 1 == args.max_epoch):
-            test_stats = evaluate(val_dataloader, model, device)
-            print(f"Accuracy of the network on the {len(val_dataset)} test images: {test_stats['acc1']:.1f}%")
-            max_accuracy = max(max_accuracy, test_stats["acc1"])
-            print(f'Max accuracy: {max_accuracy:.2f}%')
-
-            # Save model
-            print('- saving the model after {} epochs ...'.format(epoch))
-            save_model(args, epoch, model, optimizer, lr_scheduler, acc1=max_accuracy)
-
-    total_time = time.time() - start_time
-    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-    print('Training time {}'.format(total_time_str))
-
-
-if __name__ == "__main__":
-    main()

+ 0 - 211
masked_image_modeling/main_pretrain.py

@@ -1,211 +0,0 @@
-import os
-import cv2
-import time
-import datetime
-import argparse
-import numpy as np
-
-# ---------------- Torch components ----------------
-import torch
-import torch.backends.cudnn as cudnn
-
-# ---------------- Dataset components ----------------
-from data import build_dataset, build_dataloader
-from models import build_model
-
-# ---------------- Utils components ----------------
-from utils.misc import setup_seed
-from utils.misc import load_model, save_model, unpatchify
-from utils.optimizer import build_optimizer
-from utils.lr_scheduler import build_lr_scheduler, LinearWarmUpLrScheduler
-
-# ---------------- Training engine ----------------
-from engine_pretrain import train_one_epoch
-
-
-def parse_args():
-    parser = argparse.ArgumentParser()
-    # Basic
-    parser.add_argument('--seed', type=int, default=42,
-                        help='random seed.')
-    parser.add_argument('--cuda', action='store_true', default=False,
-                        help='use cuda')
-    parser.add_argument('--batch_size', type=int, default=256,
-                        help='batch size on all GPUs')
-    parser.add_argument('--num_workers', type=int, default=4,
-                        help='number of workers')
-    parser.add_argument('--path_to_save', type=str, default='weights/',
-                        help='path to save trained model.')
-    parser.add_argument('--eval', action='store_true', default=False,
-                        help='evaluate model.')
-    # Epoch
-    parser.add_argument('--wp_epoch', type=int, default=20,
-                        help='number of warmup epochs')
-    parser.add_argument('--start_epoch', type=int, default=0,
-                        help='start epoch (useful when resuming)')
-    parser.add_argument('--eval_epoch', type=int, default=10,
-                        help='interval (in epochs) for saving checkpoints')
-    parser.add_argument('--max_epoch', type=int, default=200, 
-                        help='max epoch')
-    # Dataset
-    parser.add_argument('--dataset', type=str, default='cifar10',
-                        help='dataset name')
-    parser.add_argument('--root', type=str, default='/mnt/share/ssd2/dataset',
-                        help='path to dataset folder')
-    parser.add_argument('--num_classes', type=int, default=None, 
-                        help='number of classes.')
-    # Model
-    parser.add_argument('-m', '--model', type=str, default='vit_t',
-                        help='model name')
-    parser.add_argument('--resume', default=None, type=str,
-                        help='keep training')
-    parser.add_argument('--drop_path', type=float, default=0.,
-                        help='drop_path')
-    parser.add_argument('--mask_ratio', type=float, default=0.75,
-                        help='mask ratio.')    
-    # Optimizer
-    parser.add_argument('-opt', '--optimizer', type=str, default='adamw',
-                        help='sgd, adam')
-    parser.add_argument('-wd', '--weight_decay', type=float, default=0.05,
-                        help='weight decay')
-    parser.add_argument('--base_lr', type=float, default=0.00015,
-                        help='learning rate for training model')
-    parser.add_argument('--min_lr', type=float, default=0,
-                        help='the final lr')
-    # Lr scheduler
-    parser.add_argument('-lrs', '--lr_scheduler', type=str, default='cosine',
-                        help='step, cosine')
-
-    return parser.parse_args()
-
-    
-def main():
-    args = parse_args()
-    # set random seed
-    setup_seed(args.seed)
-
-    # Path to save model
-    path_to_save = os.path.join(args.path_to_save, args.dataset, "pretrained", args.model)
-    os.makedirs(path_to_save, exist_ok=True)
-    args.output_dir = path_to_save
-    
-    # ------------------------- Build CUDA -------------------------
-    if args.cuda:
-        if torch.cuda.is_available():
-            cudnn.benchmark = True
-            device = torch.device("cuda")
-        else:
-            print('There is no available GPU.')
-            args.cuda = False
-            device = torch.device("cpu")
-    else:
-        device = torch.device("cpu")
-
-    # ------------------------- Build Dataset -------------------------
-    train_dataset = build_dataset(args, is_train=True)
-
-    # ------------------------- Build Dataloader -------------------------
-    train_dataloader = build_dataloader(args, train_dataset, is_train=True)
-    print('=================== Dataset Information ===================')
-    print('Train dataset size : {}'.format(len(train_dataset)))
-
-   # ------------------------- Build Model -------------------------
-    model = build_model(args, model_type='mae')
-    model.train().to(device)
-    print(model)
-
-    # ------------------------- Build Optimizer -------------------------
-    optimizer = build_optimizer(args, model)
-
-    # ------------------------- Build Lr Scheduler -------------------------
-    lr_scheduler_warmup = LinearWarmUpLrScheduler(args.base_lr, wp_iter=args.wp_epoch * len(train_dataloader))
-    lr_scheduler = build_lr_scheduler(args, optimizer)
-
-    # ------------------------- Build checkpoint -------------------------
-    load_model(args, model, optimizer, lr_scheduler)
-
-    # ------------------------- Eval before Train Pipeline -------------------------
-    if args.eval:
-        print('visualizing ...')
-        visualize(args, device, model)
-        return
-
-    # ------------------------- Training Pipeline -------------------------
-    start_time = time.time()
-    print("=================== Start training for {} epochs ===================".format(args.max_epoch))
-    for epoch in range(args.start_epoch, args.max_epoch):
-        # Train one epoch
-        train_one_epoch(args, device, model, train_dataloader,
-                        optimizer, epoch, lr_scheduler_warmup)
-
-        # LR scheduler
-        if (epoch + 1) > args.wp_epoch:
-            lr_scheduler.step()
-
-        # Evaluate
-        if epoch % args.eval_epoch == 0 or epoch + 1 == args.max_epoch:
-            print('- saving the model after {} epochs ...'.format(epoch))
-            save_model(args, epoch, model, optimizer, lr_scheduler, mae_task=True)
-
-    total_time = time.time() - start_time
-    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-    print('Training time {}'.format(total_time_str))
-
-def visualize(args, device, model):
-    # test dataset
-    val_dataset = build_dataset(args, is_train=False)
-    val_dataloader = build_dataloader(args, val_dataset, is_train=False)
-
-    # save path
-    save_path = "vis_results/{}/{}".format(args.dataset, args.model)
-    os.makedirs(save_path, exist_ok=True)
-
-    # switch to evaluate mode
-    model.eval()
-    patch_size = args.patch_size
-    pixel_mean = val_dataloader.dataset.pixel_mean
-    pixel_std  = val_dataloader.dataset.pixel_std
-
-    with torch.no_grad():
-        for i, (images, target) in enumerate(val_dataloader):
-            # To device
-            images = images.to(device, non_blocking=True)
-            target = target.to(device, non_blocking=True)
-
-            # Inference
-            output = model(images)
-
-            # Denormalize input image
-            org_img = images[0].permute(1, 2, 0).cpu().numpy()
-            org_img = (org_img * pixel_std + pixel_mean) * 255.
-            org_img = org_img.astype(np.uint8)
-
-            # Reshape the mask: [B, H*W] -> [B, H*W, p*p*3]
-            mask = output['mask'].unsqueeze(-1).repeat(1, 1, patch_size**2 *3)  # [B, H*W] -> [B, H*W, p*p*3]
-            # Convert the sequence-format mask back to the 2D image layout
-            mask = unpatchify(mask, patch_size)
-            mask = mask[0].permute(1, 2, 0).cpu().numpy()
-            # Black out the masked patch regions of the image
-            masked_img = org_img * (1 - mask)  # 1 is removing, 0 is keeping
-            masked_img = masked_img.astype(np.uint8)
-
-            # Convert the sequence-format reconstruction back to a 2D image
-            pred_img = unpatchify(output['x_pred'], patch_size)
-            pred_img = pred_img[0].permute(1, 2, 0).cpu().numpy()
-            pred_img = (pred_img * pixel_std + pixel_mean) * 255.
-            # Combine the kept patches from the original image with the predicted reconstructed patches
-            pred_img = org_img * (1 - mask) + pred_img * mask
-            pred_img = pred_img.astype(np.uint8)
-
-            # visualize
-            vis_image = np.concatenate([masked_img, org_img, pred_img], axis=1)
-            vis_image = vis_image[..., (2, 1, 0)]
-            cv2.imshow('masked | origin | reconstruct ', vis_image)
-            cv2.waitKey(0)
-
-            # save
-            cv2.imwrite('{}/{:06}.png'.format(save_path, i), vis_image)
-
-
-if __name__ == "__main__":
-    main()
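
The `visualize` helper above converts patch sequences back to images with `unpatchify` from the deleted `utils/misc.py`, which is not shown in this hunk. Below is a sketch of what that helper is assumed to do, mirroring MAE's reference implementation: map `[B, L, p*p*3]` back to `[B, 3, H, W]` with `L = (H/p) * (W/p)`.

```python
import torch

def unpatchify(x, patch_size):
    """Sketch (assumed to mirror MAE): [B, L, p*p*3] -> [B, 3, H, W]."""
    p = patch_size
    h = w = int(x.shape[1] ** 0.5)           # assume a square grid of patches
    x = x.reshape(x.shape[0], h, w, p, p, 3)
    x = torch.einsum('nhwpqc->nchpwq', x)    # interleave patch rows/cols back into pixel rows/cols
    return x.reshape(x.shape[0], 3, h * p, w * p)
```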

+ 0 - 9
masked_image_modeling/models/__init__.py

@@ -1,9 +0,0 @@
-from .vit.build import build_vision_transformer
-
-
-def build_model(args, model_type='default'):
-    # ----------- Vision Transformer -----------
-    if "vit" in args.model:
-        return build_vision_transformer(args, model_type)
-    else:
-        raise NotImplementedError("Unknown model: {}".format(args.model))

+ 0 - 3
masked_image_modeling/models/vit/__init__.py

@@ -1,3 +0,0 @@
-from .vit import build_vit
-from .vit_mae import build_vit_mae
-from .vit_cls import ViTForImageClassification

+ 0 - 45
masked_image_modeling/models/vit/build.py

@@ -1,45 +0,0 @@
-import os
-import torch
-
-from .vit     import build_vit
-from .vit_mae import build_vit_mae
-from .vit_cls import ViTForImageClassification
-
-
-def build_vision_transformer(args, model_type='default'):
-    assert args.model in ['vit_t', 'vit_s', 'vit_b', 'vit_l', 'vit_h'], "Unknown vit model: {}".format(args.model)
-
-    # ----------- Masked Image Modeling task -----------
-    if model_type == 'mae':
-        model = build_vit_mae(args.model, args.img_size, args.patch_size, args.img_dim, args.mask_ratio)
-    
-    # ----------- Image Classification task -----------
-    elif model_type == 'cls':
-        image_encoder = build_vit(args.model, args.img_size, args.patch_size, args.img_dim)
-        model = ViTForImageClassification(image_encoder, num_classes=args.num_classes, qkv_bias=True)
-        load_mae_pretrained(model.encoder, args.pretrained)
-
-    # ----------- Vison Backbone -----------
-    elif model_type == 'default':
-        model = build_vit(args.model, args.img_size, args.patch_size, args.img_dim)
-        load_mae_pretrained(model, args.pretrained)
-        
-    else:
-        raise NotImplementedError("Unknown model type: {}".format(model_type))
-    
-    return model
-
-
-def load_mae_pretrained(model, ckpt=None):
-    if ckpt is not None:
-        # check path
-        if not os.path.exists(ckpt):
-            print("No pretrained model.")
-            return model
-        print('- Loading pretrained from: {}'.format(ckpt))
-        checkpoint = torch.load(ckpt, map_location='cpu')
-        # checkpoint state dict
-        encoder_state_dict = checkpoint.pop("encoder")
-
-        # load encoder weight into ViT's encoder
-        model.load_state_dict(encoder_state_dict)

+ 0 - 186
masked_image_modeling/models/vit/modules.py

@@ -1,186 +0,0 @@
-# --------------------------------------------------------------------
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------------------
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from typing import Type
-
-
-# ----------------------- Basic modules -----------------------
-class FeedFroward(nn.Module):
-    def __init__(self,
-                 embedding_dim: int,
-                 mlp_dim: int,
-                 act: Type[nn.Module] = nn.GELU,
-                 dropout: float = 0.0,
-                 ) -> None:
-        super().__init__()
-        self.fc1   = nn.Linear(embedding_dim, mlp_dim)
-        self.drop1 = nn.Dropout(dropout)
-        self.fc2   = nn.Linear(mlp_dim, embedding_dim)
-        self.drop2 = nn.Dropout(dropout)
-        self.act   = act()
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop1(x)
-        x = self.fc2(x)
-        x = self.drop2(x)
-        return x
-
-class PatchEmbed(nn.Module):
-    def __init__(self,
-                 in_chans    : int = 3,
-                 embed_dim   : int = 768,
-                 kernel_size : int = 16,
-                 padding     : int = 0,
-                 stride      : int = 16,
-                 ) -> None:
-        super().__init__()
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.proj(x)
-
-
-# ----------------------- Model modules -----------------------
-class ViTBlock(nn.Module):
-    def __init__(self,
-                 dim       :int,
-                 num_heads :int,
-                 mlp_ratio :float = 4.0,
-                 qkv_bias  :bool = True,
-                 act_layer :Type[nn.Module] = nn.GELU,
-                 dropout   :float = 0.
-                 ) -> None:
-        super().__init__()
-        # -------------- Model parameters --------------
-        self.norm1 = nn.LayerNorm(dim)
-        self.attn  = Attention(dim         = dim,
-                               qkv_bias    = qkv_bias,
-                               num_heads   = num_heads,
-                               dropout     = dropout
-                               )
-        self.norm2 = nn.LayerNorm(dim)
-        self.ffn   = FeedFroward(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        shortcut = x
-        # Attention (with prenorm)
-        x = self.norm1(x)
-        x = self.attn(x)
-        x = shortcut + x
-
-        # Feedforward (with prenorm)
-        x = x + self.ffn(self.norm2(x))
-
-        return x
-
-class Attention(nn.Module):
-    def __init__(self,
-                 dim       :int,
-                 qkv_bias  :bool  = False,
-                 num_heads :int   = 8,
-                 dropout   :float = 0.
-                 ):
-        super().__init__()
-        # --------------- Basic parameters ---------------
-        self.dim = dim
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.scale = self.head_dim ** -0.5
-
-        # --------------- Network parameters ---------------
-        self.qkv_proj = nn.Linear(dim, dim*3, bias = qkv_bias)
-        self.attn_drop = nn.Dropout(dropout)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(dropout)
-
-    def forward(self, x):
-        bs, N, _ = x.shape
-        # ----------------- Input proj -----------------
-        qkv = self.qkv_proj(x)
-        q, k, v = torch.chunk(qkv, 3, dim=-1)
-
-        # ----------------- Multi-head Attn -----------------
-        ## [B, N, C] -> [B, N, H, C_h] -> [B, H, N, C_h]
-        q = q.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
-        k = k.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
-        v = v.view(bs, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3).contiguous()
-        ## [B, H, Nq, C_h] X [B, H, C_h, Nk] = [B, H, Nq, Nk]
-        attn = q * self.scale @ k.transpose(-1, -2)
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        x = attn @ v # [B, H, Nq, C_h]
-
-        # ----------------- Output -----------------
-        x = x.permute(0, 2, 1, 3).contiguous().view(bs, N, -1)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        return x
-
-
-# ----------------------- Classifier -----------------------
-class AttentionPoolingClassifier(nn.Module):
-    def __init__(
-        self,
-        in_dim      : int,
-        out_dim     : int,
-        num_heads   : int = 12,
-        qkv_bias    : bool = False,
-        num_queries : int = 1,
-    ):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = in_dim // num_heads
-        self.scale = head_dim**-0.5
-
-        self.k = nn.Linear(in_dim, in_dim, bias=qkv_bias)
-        self.v = nn.Linear(in_dim, in_dim, bias=qkv_bias)
-
-        self.cls_token = nn.Parameter(torch.randn(1, num_queries, in_dim) * 0.02)
-        self.linear = nn.Linear(in_dim, out_dim)
-        self.bn = nn.BatchNorm1d(in_dim, affine=False, eps=1e-6)
-
-        self.num_queries = num_queries
-
-    def forward(self, x: torch.Tensor):
-        B, N, C = x.shape
-
-        x = self.bn(x.transpose(-2, -1)).transpose(-2, -1)
-        cls_token = self.cls_token.expand(B, -1, -1)  # newly created class token
-
-        q = cls_token.reshape(
-            B, self.num_queries, self.num_heads, C // self.num_heads
-        ).permute(0, 2, 1, 3)
-        k = (
-            self.k(x)
-            .reshape(B, N, self.num_heads, C // self.num_heads)
-            .permute(0, 2, 1, 3)
-        )
-
-        q = q * self.scale
-        v = (
-            self.v(x)
-            .reshape(B, N, self.num_heads, C // self.num_heads)
-            .permute(0, 2, 1, 3)
-        )
-
-        attn = q @ k.transpose(-2, -1)
-        attn = attn.softmax(dim=-1)
-
-        x_cls = (attn @ v).transpose(1, 2).reshape(B, self.num_queries, C)
-        x_cls = x_cls.mean(dim=1)
-
-        out = self.linear(x_cls)
-
-        return out, x_cls

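A quick, hypothetical shape check for the blocks deleted above, assuming ViTBlock and AttentionPoolingClassifier are defined exactly as in the removed modules.py:

import torch

# Sanity-check sketch (not part of the repository): token shapes through one
# pre-norm ViT block and the attention-pooling classifier head.
tokens = torch.randn(2, 196, 192)                  # [B, N, C] for a 224/16 ViT-Tiny
block  = ViTBlock(dim=192, num_heads=3)
pooler = AttentionPoolingClassifier(in_dim=192, out_dim=1000, num_heads=3)

x = block(tokens)                                  # attention + FFN, shape preserved: [2, 196, 192]
logits, cls_feat = pooler(x)                       # [2, 1000] logits, [2, 192] pooled feature
print(x.shape, logits.shape, cls_feat.shape)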
+ 0 - 96
masked_image_modeling/models/vit/pos_embed.py

@@ -1,96 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# Position embedding utils
-# --------------------------------------------------------
-
-import numpy as np
-
-import torch
-
-# --------------------------------------------------------
-# 2D sine-cosine position embedding
-# References:
-# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
-# MoCo v3: https://github.com/facebookresearch/moco-v3
-# --------------------------------------------------------
-def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
-    """
-    grid_size: int of the grid height and width
-    return:
-    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
-    """
-    grid_h = np.arange(grid_size, dtype=np.float32)
-    grid_w = np.arange(grid_size, dtype=np.float32)
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
-
-    grid = grid.reshape([2, 1, grid_size, grid_size])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
-    if cls_token:
-        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
-    return pos_embed
-
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
-    assert embed_dim % 2 == 0
-
-    # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
-
-    emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
-    return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-    """
-    embed_dim: output dimension for each position
-    pos: a list of positions to be encoded: size (M,)
-    out: (M, D)
-    """
-    assert embed_dim % 2 == 0
-    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float alias was removed in NumPy >= 1.24
-    omega /= embed_dim / 2.
-    omega = 1. / 10000**omega  # (D/2,)
-
-    pos = pos.reshape(-1)  # (M,)
-    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
-
-    emb_sin = np.sin(out) # (M, D/2)
-    emb_cos = np.cos(out) # (M, D/2)
-
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
-    return emb
-
-
-# --------------------------------------------------------
-# Interpolate position embeddings for high-resolution
-# References:
-# DeiT: https://github.com/facebookresearch/deit
-# --------------------------------------------------------
-def interpolate_pos_embed(model, checkpoint_model):
-    if 'pos_embed' in checkpoint_model:
-        pos_embed_checkpoint = checkpoint_model['pos_embed']
-        embedding_size = pos_embed_checkpoint.shape[-1]
-        num_patches = model.num_patches
-        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
-        # height (== width) for the checkpoint position embedding
-        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
-        # height (== width) for the new position embedding
-        new_size = int(num_patches ** 0.5)
-        # class_token and dist_token are kept unchanged
-        if orig_size != new_size:
-            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
-            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
-            # only the position tokens are interpolated
-            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
-            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
-            pos_tokens = torch.nn.functional.interpolate(
-                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
-            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
-            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
-            checkpoint_model['pos_embed'] = new_pos_embed

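The sin-cos utilities removed above are typically used to build a frozen positional table; a hedged usage sketch (values and the model attribute name are illustrative only):

import torch

# For a 224x224 image with 16x16 patches the grid is 14x14 -> 196 positions.
embed_dim, grid_size = 768, 14
pos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)   # numpy array, shape (197, 768)
pos = torch.from_numpy(pos).float().unsqueeze(0)                      # [1, 197, 768]
# model.pos_embed.data.copy_(pos)   # how MAE-style code usually initializes a frozen pos_embed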
+ 0 - 180
masked_image_modeling/models/vit/vit.py

@@ -1,180 +0,0 @@
-# --------------------------------------------------------------------
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------------------
-
-import torch
-import torch.nn as nn
-
-try:
-    from .modules import PatchEmbed, ViTBlock
-except:
-    from  modules import PatchEmbed, ViTBlock
-
-
-# ---------------------- Vision transformer ----------------------
-class ImageEncoderViT(nn.Module):
-    def __init__(self,
-                 img_size: int,
-                 patch_size: int,
-                 in_chans: int,
-                 patch_embed_dim: int,
-                 depth: int,
-                 num_heads: int,
-                 mlp_ratio: float,
-                 act_layer: nn.GELU,
-                 dropout: float = 0.0,
-                 ) -> None:
-        super().__init__()
-        # ----------- Basic parameters -----------
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
-        self.patch_embed_dim = patch_embed_dim
-        self.num_heads = num_heads
-        self.num_patches = (img_size // patch_size) ** 2
-        # ----------- Model parameters -----------
-        self.patch_embed = PatchEmbed(in_chans, patch_embed_dim, patch_size, stride=patch_size)
-        self.pos_embed   = nn.Parameter(torch.zeros(1, self.num_patches, patch_embed_dim))
-        self.norm_layer  = nn.LayerNorm(patch_embed_dim)
-        self.blocks      = nn.ModuleList([
-            ViTBlock(patch_embed_dim, num_heads, mlp_ratio, True, act_layer, dropout)
-            for _ in range(depth)])
-
-        self._init_weights()
-
-    def _init_weights(self):
-        # initialize (and freeze) pos_embed by sin-cos embedding
-        pos_embed = self.get_posembed(self.pos_embed.shape[-1], int(self.num_patches**.5))
-        self.pos_embed.data.copy_(pos_embed)
-
-        # initialize nn.Linear and nn.LayerNorm
-        for m in self.modules():           
-            if isinstance(m, nn.Linear):
-                # we use xavier_uniform following official JAX ViT:
-                torch.nn.init.xavier_uniform_(m.weight)
-                if isinstance(m, nn.Linear) and m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.LayerNorm):
-                nn.init.constant_(m.bias, 0)
-                nn.init.constant_(m.weight, 1.0)
-
-    def get_posembed(self, embed_dim, grid_size, temperature=10000):
-        scale = 2 * torch.pi
-        grid_h, grid_w = grid_size, grid_size
-        num_pos_feats = embed_dim // 2
-        # get grid
-        y_embed, x_embed = torch.meshgrid([torch.arange(grid_h, dtype=torch.float32),
-                                           torch.arange(grid_w, dtype=torch.float32)])
-        # normalize grid coords
-        y_embed = y_embed / (grid_h + 1e-6) * scale
-        x_embed = x_embed / (grid_w + 1e-6) * scale
-    
-        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
-        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
-        dim_t = temperature ** (2 * dim_t_)
-
-        pos_x = torch.div(x_embed[..., None], dim_t)
-        pos_y = torch.div(y_embed[..., None], dim_t)
-        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
-        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
-
-        # [H, W, C] -> [N, C]
-        pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)
-
-        return pos_embed.unsqueeze(0)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # Patch embed
-        x = self.patch_embed(x)
-        x = x.flatten(2).permute(0, 2, 1).contiguous()
-
-        # Add pos embed
-        x = x + self.pos_embed
-
-        # Apply Transformer blocks
-        for block in self.blocks:
-            x = block(x)
-        x = self.norm_layer(x)
-
-        return x
-
-
-# ------------------------ Model Functions ------------------------
-def build_vit(model_name="vit_t", img_size=224, patch_size=16, img_dim=3):
-    if model_name == "vit_t":
-        return ImageEncoderViT(img_size=img_size,
-                               patch_size=patch_size,
-                               in_chans=img_dim,
-                               patch_embed_dim=192,
-                               depth=12,
-                               num_heads=3,
-                               mlp_ratio=4.0,
-                               act_layer=nn.GELU,
-                               dropout = 0.1)
-    if model_name == "vit_s":
-        return ImageEncoderViT(img_size=img_size,
-                               patch_size=patch_size,
-                               in_chans=img_dim,
-                               patch_embed_dim=384,
-                               depth=12,
-                               num_heads=6,
-                               mlp_ratio=4.0,
-                               act_layer=nn.GELU,
-                               dropout = 0.1)
-    if model_name == "vit_b":
-        return ImageEncoderViT(img_size=img_size,
-                               patch_size=patch_size,
-                               in_chans=img_dim,
-                               patch_embed_dim=768,
-                               depth=12,
-                               num_heads=12,
-                               mlp_ratio=4.0,
-                               act_layer=nn.GELU,
-                               dropout = 0.1)
-    if model_name == "vit_l":
-        return ImageEncoderViT(img_size=img_size,
-                               patch_size=patch_size,
-                               in_chans=img_dim,
-                               patch_embed_dim=1024,
-                               depth=24,
-                               num_heads=16,
-                               mlp_ratio=4.0,
-                               act_layer=nn.GELU,
-                               dropout = 0.1)
-    if model_name == "vit_h":
-        return ImageEncoderViT(img_size=img_size,
-                               patch_size=patch_size,
-                               in_chans=img_dim,
-                               patch_embed_dim=1280,
-                               depth=32,
-                               num_heads=16,
-                               mlp_ratio=4.0,
-                               act_layer=nn.GELU,
-                               dropout = 0.1)
-    
-
-if __name__ == '__main__':
-    import torch
-    from thop import profile
-
-    # Prepare an image as the input
-    bs, c, h, w = 2, 3, 224, 224
-    x = torch.randn(bs, c, h, w)
-    patch_size = 16
-
-    # Build model
-    model = build_vit(patch_size=patch_size)
-
-    # Inference
-    outputs = model(x)
-
-    # Compute FLOPs & Params
-    print('==============================')
-    model.eval()
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))

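The encoder above returns a token sequence rather than logits; a hypothetical way to turn it into a single image embedding (mean pooling is an assumption here, not something the deleted file did):

import torch

model = build_vit("vit_t", img_size=224, patch_size=16)
x = torch.randn(2, 3, 224, 224)
tokens = model(x)                 # [2, 196, 192] patch tokens
feature = tokens.mean(dim=1)      # [2, 192] global image embedding (mean pooling)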
+ 0 - 28
masked_image_modeling/models/vit/vit_cls.py

@@ -1,28 +0,0 @@
-import torch.nn as nn
-
-from .modules import AttentionPoolingClassifier
-from .vit     import ImageEncoderViT
-
-
-class ViTForImageClassification(nn.Module):
-    def __init__(self,
-                 image_encoder :ImageEncoderViT,
-                 num_classes   :int   = 1000,
-                 qkv_bias      :bool  = True,
-                 ):
-        super().__init__()
-        # -------- Model parameters --------
-        self.encoder    = image_encoder
-        self.classifier = AttentionPoolingClassifier(
-            image_encoder.patch_embed_dim, num_classes, image_encoder.num_heads, qkv_bias, num_queries=1)
-
-    def forward(self, x):
-        """
-        Inputs:
-            x: (torch.Tensor) -> [B, C, H, W]. Input image.
-        """
-        x = self.encoder(x)
-        x, x_cls = self.classifier(x)
-
-        return x
-

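A hedged sketch of how this wrapper is usually assembled with the builder from the deleted vit.py in the same package:

import torch

encoder = build_vit("vit_t", img_size=224, patch_size=16)
model = ViTForImageClassification(encoder, num_classes=1000)
logits = model(torch.randn(2, 3, 224, 224))   # [2, 1000]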
+ 0 - 399
masked_image_modeling/models/vit/vit_mae.py

@@ -1,399 +0,0 @@
-import math
-import torch
-import torch.nn as nn
-
-try:
-    from .modules import ViTBlock, PatchEmbed
-except:
-    from  modules import ViTBlock, PatchEmbed
-
-
-# ------------------------ Basic Modules ------------------------
-class MaeEncoder(nn.Module):
-    def __init__(self,
-                 img_size: int,
-                 patch_size: int,
-                 in_chans: int,
-                 patch_embed_dim: int,
-                 depth: int,
-                 num_heads: int,
-                 mlp_ratio: float,
-                 act_layer: nn.GELU,
-                 mask_ratio: float = 0.75,
-                 dropout: float = 0.0,
-                 ) -> None:
-        super().__init__()
-        # ----------- Basic parameters -----------
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
-        self.patch_embed_dim = patch_embed_dim
-        self.num_heads = num_heads
-        self.num_patches = (img_size // patch_size) ** 2
-        self.mask_ratio = mask_ratio
-        # ----------- Model parameters -----------
-        self.patch_embed = PatchEmbed(in_chans, patch_embed_dim, patch_size, 0, patch_size)
-        self.pos_embed   = nn.Parameter(torch.zeros(1, self.num_patches, patch_embed_dim), requires_grad=False)
-        self.norm_layer  = nn.LayerNorm(patch_embed_dim)
-        self.blocks      = nn.ModuleList([
-            ViTBlock(patch_embed_dim, num_heads, mlp_ratio, True, act_layer=act_layer, dropout=dropout)
-            for _ in range(depth)])
-        self._init_weights()
-
-    def _init_weights(self):
-        # initialize (and freeze) pos_embed by sin-cos embedding
-        pos_embed = self.get_posembed(self.pos_embed.shape[-1], int(self.num_patches**.5))
-        self.pos_embed.data.copy_(pos_embed)
-
-        # initialize nn.Linear and nn.LayerNorm
-        for m in self.modules():           
-            if isinstance(m, nn.Linear):
-                # we use xavier_uniform following official JAX ViT:
-                torch.nn.init.xavier_uniform_(m.weight)
-                if isinstance(m, nn.Linear) and m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.LayerNorm):
-                nn.init.constant_(m.bias, 0)
-                nn.init.constant_(m.weight, 1.0)
-
-    def get_posembed(self, embed_dim, grid_size, temperature=10000):
-        scale = 2 * math.pi
-        grid_h, grid_w = grid_size, grid_size
-        num_pos_feats = embed_dim // 2
-        # get grid
-        y_embed, x_embed = torch.meshgrid([torch.arange(grid_h, dtype=torch.float32),
-                                           torch.arange(grid_w, dtype=torch.float32)])
-        # normalize grid coords
-        y_embed = y_embed / (grid_h + 1e-6) * scale
-        x_embed = x_embed / (grid_w + 1e-6) * scale
-    
-        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
-        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
-        dim_t = temperature ** (2 * dim_t_)
-
-        pos_x = torch.div(x_embed[..., None], dim_t)
-        pos_y = torch.div(y_embed[..., None], dim_t)
-        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
-        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
-
-        # [H, W, C] -> [N, C]
-        pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)
-
-        return pos_embed.unsqueeze(0)
-
-    def random_masking(self, x):
-        B, N, C = x.shape
-        len_keep = int(N * (1 - self.mask_ratio))
-
-        noise = torch.rand(B, N, device=x.device)  # noise in [0, 1]
-
-        # sort noise for each sample
-        ids_shuffle = torch.argsort(noise, dim=1)        # ascend: small is keep, large is remove
-        ids_restore = torch.argsort(ids_shuffle, dim=1)  # restore the original position of each patch
-
-        # keep the first subset
-        ids_keep = ids_shuffle[:, :len_keep]
-        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, C))
-
-        # generate the binary mask: 0 is keep, 1 is remove
-        mask = torch.ones([B, N], device=x.device)
-        mask[:, :len_keep] = 0
-
-        # unshuffle to get the binary mask
-        mask = torch.gather(mask, dim=1, index=ids_restore)
-
-        return x_masked, mask, ids_restore
-    
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # patch embed
-        x = self.patch_embed(x)
-        # [B, C, H, W] -> [B, C, N] -> [B, N, C], N = H x W
-        x = x.flatten(2).permute(0, 2, 1).contiguous()
-
-        # add pos embed
-        x = x + self.pos_embed
-
-        # masking: length -> length * mask_ratio
-        x, mask, ids_restore = self.random_masking(x)
-
-        # apply Transformer blocks
-        for block in self.blocks:
-            x = block(x)
-        x = self.norm_layer(x)
-        
-        return x, mask, ids_restore
-
-class MaeDecoder(nn.Module):
-    def __init__(self,
-                 img_dim       :int   = 3,
-                 img_size      :int   = 16,
-                 patch_size    :int   = 16,
-                 en_emb_dim    :int   = 784,
-                 de_emb_dim    :int   = 512,
-                 de_num_layers :int   = 12,
-                 de_num_heads  :int   = 12,
-                 qkv_bias      :bool  = True,
-                 mlp_ratio     :float = 4.0,
-                 dropout       :float = 0.1,
-                 mask_ratio    :float = 0.75,
-                 ):
-        super().__init__()
-        # -------- basic parameters --------
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.num_patches = (img_size // patch_size) ** 2
-        self.en_emb_dim = en_emb_dim
-        self.de_emb_dim = de_emb_dim
-        self.de_num_layers = de_num_layers
-        self.de_num_heads = de_num_heads
-        self.mask_ratio = mask_ratio
-        # -------- network parameters --------
-        self.decoder_embed     = nn.Linear(en_emb_dim, de_emb_dim)
-        self.mask_token        = nn.Parameter(torch.zeros(1, 1, de_emb_dim))
-        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, de_emb_dim), requires_grad=False)  # fixed sin-cos embedding
-        self.decoder_norm      = nn.LayerNorm(de_emb_dim)
-        self.decoder_pred      = nn.Linear(de_emb_dim, patch_size**2 * img_dim, bias=True)
-        self.blocks            = nn.ModuleList([
-            ViTBlock(de_emb_dim, de_num_heads, mlp_ratio, qkv_bias, dropout=dropout)
-            for _ in range(de_num_layers)])
-        
-        self._init_weights()
-
-    def _init_weights(self):
-        # initialize (and freeze) pos_embed by sin-cos embedding
-        decoder_pos_embed = self.get_posembed(self.decoder_pos_embed.shape[-1], int(self.num_patches**.5))
-        self.decoder_pos_embed.data.copy_(decoder_pos_embed)
-
-        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
-        torch.nn.init.normal_(self.mask_token, std=.02)
-
-        # initialize nn.Linear and nn.LayerNorm
-        for m in self.modules():           
-            if isinstance(m, nn.Linear):
-                # we use xavier_uniform following official JAX ViT:
-                torch.nn.init.xavier_uniform_(m.weight)
-                if isinstance(m, nn.Linear) and m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
-            elif isinstance(m, nn.LayerNorm):
-                nn.init.constant_(m.bias, 0)
-                nn.init.constant_(m.weight, 1.0)
-
-    def get_posembed(self, embed_dim, grid_size, temperature=10000):
-        scale = 2 * math.pi
-        grid_h, grid_w = grid_size, grid_size
-        num_pos_feats = embed_dim // 2
-        # get grid
-        y_embed, x_embed = torch.meshgrid([torch.arange(grid_h, dtype=torch.float32),
-                                           torch.arange(grid_w, dtype=torch.float32)])
-        # normalize grid coords
-        y_embed = y_embed / (grid_h + 1e-6) * scale
-        x_embed = x_embed / (grid_w + 1e-6) * scale
-    
-        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
-        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
-        dim_t = temperature ** (2 * dim_t_)
-
-        pos_x = torch.div(x_embed[..., None], dim_t)
-        pos_y = torch.div(y_embed[..., None], dim_t)
-        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
-        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
-
-        # [H, W, C] -> [N, C]
-        pos_embed = torch.cat((pos_y, pos_x), dim=-1).view(-1, embed_dim)
-
-        return pos_embed.unsqueeze(0)
-
-    def forward(self, x_enc, ids_restore):
-        # embed tokens
-        x_enc = self.decoder_embed(x_enc)
-        B, N_nomask, C = x_enc.shape
-
-        # append mask tokens to sequence
-        mask_tokens = self.mask_token.repeat(B, ids_restore.shape[1] - N_nomask, 1)     # [B, N_mask, C], N_mask = N - N_nomask
-        x_all = torch.cat([x_enc, mask_tokens], dim=1)
-        x_all = torch.gather(x_all, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, C))  # unshuffle
-
-        # add pos embed
-        x_all = x_all + self.decoder_pos_embed
-
-        # apply Transformer blocks
-        for block in self.blocks:
-            x_all = block(x_all)
-        x_all = self.decoder_norm(x_all)
-
-        # predict
-        x_out = self.decoder_pred(x_all)
-
-        return x_out
-
-
-# ------------------------ MAE Vision Transformer ------------------------
-class ViTforMaskedAutoEncoder(nn.Module):
-    def __init__(self,
-                 encoder :MaeEncoder,
-                 decoder :MaeDecoder,
-                 ):
-        super().__init__()
-        self.mae_encoder = encoder
-        self.mae_decoder = decoder
-
-    def patchify(self, imgs, patch_size):
-        """
-        imgs: (B, 3, H, W)
-        x: (B, N, patch_size**2 *3)
-        """
-        p = patch_size
-        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
-
-        h = w = imgs.shape[2] // p
-        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
-        x = torch.einsum('nchpwq->nhwpqc', x)
-        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
-
-        return x
-    
-    def unpatchify(self, x, patch_size):
-        """
-        x: (B, N, patch_size**2 *3)
-        imgs: (B, 3, H, W)
-        """
-        p = patch_size
-        h = w = int(x.shape[1]**.5)
-        assert h * w == x.shape[1]
-        
-        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
-        x = torch.einsum('nhwpqc->nchpwq', x)
-        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
-
-        return imgs
-
-    def compute_loss(self, x, output):
-        """
-        imgs: [B, 3, H, W]
-        pred: [B, N, C], C = p*p*3
-        mask: [B, N], 0 is keep, 1 is remove, 
-        """
-        target = self.patchify(x, self.mae_encoder.patch_size)
-        pred, mask = output["x_pred"], output["mask"]
-        loss = (pred - target) ** 2
-        loss = loss.mean(dim=-1)  # [B, N], mean loss per patch
-        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
-        
-        return loss
-
-    def forward(self, x):
-        imgs = x
-        x, mask, ids_restore = self.mae_encoder(x)
-        x = self.mae_decoder(x, ids_restore)
-        output = {
-            'x_pred': x,
-            'mask': mask
-        }
-
-        if self.training:
-            loss = self.compute_loss(imgs, output)
-            output["loss"] = loss
-
-        return output
-
-
-# ------------------------ Model Functions ------------------------
-def build_vit_mae(model_name="vit_t", img_size=224, patch_size=16, img_dim=3, mask_ratio=0.75):
-    # ---------------- MAE Encoder ----------------
-    if model_name == "vit_t":
-        encoder = MaeEncoder(img_size=img_size,
-                             patch_size=patch_size,
-                             in_chans=img_dim,
-                             patch_embed_dim=192,
-                             depth=12,
-                             num_heads=3,
-                             mlp_ratio=4.0,
-                             act_layer=nn.GELU,
-                             mask_ratio=mask_ratio,
-                             dropout = 0.1)
-    if model_name == "vit_s":
-        encoder = MaeEncoder(img_size=img_size,
-                             patch_size=patch_size,
-                             in_chans=img_dim,
-                             patch_embed_dim=384,
-                             depth=12,
-                             num_heads=6,
-                             mlp_ratio=4.0,
-                             act_layer=nn.GELU,
-                             mask_ratio=mask_ratio,
-                             dropout = 0.1)
-    if model_name == "vit_b":
-        encoder = MaeEncoder(img_size=img_size,
-                             patch_size=patch_size,
-                             in_chans=img_dim,
-                             patch_embed_dim=768,
-                             depth=12,
-                             num_heads=12,
-                             mlp_ratio=4.0,
-                             act_layer=nn.GELU,
-                             mask_ratio=mask_ratio,
-                             dropout = 0.1)
-    if model_name == "vit_l":
-        encoder = MaeEncoder(img_size=img_size,
-                             patch_size=patch_size,
-                             in_chans=img_dim,
-                             patch_embed_dim=1024,
-                             depth=24,
-                             num_heads=16,
-                             mlp_ratio=4.0,
-                             act_layer=nn.GELU,
-                             mask_ratio=mask_ratio,
-                             dropout = 0.1)
-    if model_name == "vit_h":
-        encoder = MaeEncoder(img_size=img_size,
-                             patch_size=patch_size,
-                             in_chans=img_dim,
-                             patch_embed_dim=1280,
-                             depth=32,
-                             num_heads=16,
-                             mlp_ratio=4.0,
-                             act_layer=nn.GELU,
-                             mask_ratio=mask_ratio,
-                             dropout = 0.1)
-    
-    # ---------------- MAE Decoder ----------------
-    decoder = MaeDecoder(img_dim = img_dim,
-                         img_size=img_size,
-                         patch_size=patch_size,
-                         en_emb_dim=encoder.patch_embed_dim,
-                         de_emb_dim=512,
-                         de_num_layers=8,
-                         de_num_heads=16,
-                         qkv_bias=True,
-                         mlp_ratio=4.0,
-                         mask_ratio=mask_ratio,
-                         dropout=0.1,)
-    
-    return ViTforMaskedAutoEncoder(encoder, decoder)
-
-
-if __name__ == '__main__':
-    import torch
-    from thop import profile
-
-    # Prepare an image as the input
-    bs, c, h, w = 2, 3, 224, 224
-    x = torch.randn(bs, c, h, w)
-    patch_size = 16
-
-    # Build model
-    model = build_vit_mae(patch_size=patch_size)
-
-    # Inference
-    outputs = model(x)
-    if "loss" in outputs:
-        print("Loss: ", outputs["loss"].item())
-
-    # Compute FLOPs & Params
-    print('==============================')
-    model.eval()
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
-

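A minimal, hypothetical pre-training step with the MAE model built above; in train mode the forward pass already returns the reconstruction loss on the masked patches, so one step reduces to backward plus an optimizer update:

import torch

model = build_vit_mae("vit_t", img_size=224, patch_size=16, mask_ratio=0.75)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4, weight_decay=0.05)

model.train()
images = torch.randn(4, 3, 224, 224)   # stand-in for a data-loader batch
out = model(images)                    # dict with 'x_pred', 'mask' and 'loss'
out["loss"].backward()
optimizer.step()
optimizer.zero_grad()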
+ 0 - 5
masked_image_modeling/requirements.txt

@@ -1,5 +0,0 @@
-torch
-torchvision
-opencv-python
-thop
-timm

+ 0 - 37
masked_image_modeling/utils/lr_scheduler.py

@@ -1,37 +0,0 @@
-import torch
-
-
-# Basic Warmup Scheduler
-class LinearWarmUpLrScheduler(object):
-    def __init__(self, base_lr=0.01, wp_iter=500, warmup_factor=0.00066667):
-        self.base_lr = base_lr
-        self.wp_iter = wp_iter
-        self.warmup_factor = warmup_factor
-
-    def set_lr(self, optimizer, cur_lr):
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = cur_lr
-
-    def __call__(self, iter, optimizer):
-        # warmup
-        assert iter < self.wp_iter
-        alpha = iter / self.wp_iter
-        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
-        tmp_lr = self.base_lr * warmup_factor
-        self.set_lr(optimizer, tmp_lr)
-
-
-def build_lr_scheduler(args, optimizer):
-    if args.lr_scheduler == "step":
-        lr_step = [args.max_epoch // 3, args.max_epoch // 3 * 2]
-        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_step, gamma=0.1)
-    elif args.lr_scheduler == "cosine":
-        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch - args.wp_epoch - 1, eta_min=args.min_lr)
-    else:
-        raise NotImplementedError("Unknown lr scheduler: {}".format(args.lr_scheduler))
-    
-    print("=================== LR Scheduler information ===================")
-    print("LR Scheduler: ", args.lr_scheduler)
-
-    return scheduler
-        

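A hedged sketch of how the two pieces above are usually combined: the warm-up object is called per iteration until wp_iter, after which the epoch-level scheduler takes over (the dummy model and loop sizes are placeholders):

import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=90, eta_min=1e-6)
warmup = LinearWarmUpLrScheduler(base_lr=1e-3, wp_iter=500)

global_iter = 0
for epoch in range(100):
    for _ in range(50):                      # stand-in for iterating a data loader
        if global_iter < warmup.wp_iter:
            warmup(global_iter, optimizer)   # linearly ramp the lr up to base_lr
        global_iter += 1
    if global_iter >= warmup.wp_iter:
        scheduler.step()                     # decay the lr once warm-up has finished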
+ 0 - 231
masked_image_modeling/utils/misc.py

@@ -1,231 +0,0 @@
-import time
-import torch
-import numpy as np
-import random
-import datetime
-from collections import defaultdict, deque
-from pathlib import Path
-
-
-# ---------------------- Common functions ----------------------
-def setup_seed(seed=42):
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    np.random.seed(seed)
-    random.seed(seed)
-    torch.backends.cudnn.deterministic = True
-
-def accuracy(output, target, topk=(1,)):
-    """Computes the accuracy over the k top predictions for the specified values of k"""
-    with torch.no_grad():
-        maxk = max(topk)
-        batch_size = target.size(0)
-
-        _, pred = output.topk(maxk, 1, True, True)
-        pred = pred.t()
-        correct = pred.eq(target.reshape(1, -1).expand_as(pred))
-
-        res = []
-        for k in topk:
-            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-            res.append(correct_k.mul_(100.0 / batch_size))
-        return res
-
-class SmoothedValue(object):
-    """Track a series of values and provide access to smoothed values over a
-    window or the global series average.
-    """
-    def __init__(self, window_size=20, fmt=None):
-        if fmt is None:
-            fmt = "{median:.4f} ({global_avg:.4f})"
-        self.deque = deque(maxlen=window_size)
-        self.total = 0.0
-        self.count = 0
-        self.fmt = fmt
-
-    def update(self, value, n=1):
-        self.deque.append(value)
-        self.count += n
-        self.total += value * n
-
-    @property
-    def median(self):
-        d = torch.tensor(list(self.deque))
-        return d.median().item()
-
-    @property
-    def avg(self):
-        d = torch.tensor(list(self.deque), dtype=torch.float32)
-        return d.mean().item()
-
-    @property
-    def global_avg(self):
-        return self.total / self.count
-
-    @property
-    def max(self):
-        return max(self.deque)
-
-    @property
-    def value(self):
-        return self.deque[-1]
-
-    def __str__(self):
-        return self.fmt.format(
-            median=self.median,
-            avg=self.avg,
-            global_avg=self.global_avg,
-            max=self.max,
-            value=self.value)
-
-class MetricLogger(object):
-    def __init__(self, delimiter="\t"):
-        self.meters = defaultdict(SmoothedValue)
-        self.delimiter = delimiter
-
-    def update(self, **kwargs):
-        for k, v in kwargs.items():
-            if v is None:
-                continue
-            if isinstance(v, torch.Tensor):
-                v = v.item()
-            assert isinstance(v, (float, int))
-            self.meters[k].update(v)
-
-    def __getattr__(self, attr):
-        if attr in self.meters:
-            return self.meters[attr]
-        if attr in self.__dict__:
-            return self.__dict__[attr]
-        raise AttributeError("'{}' object has no attribute '{}'".format(
-            type(self).__name__, attr))
-
-    def __str__(self):
-        loss_str = []
-        for name, meter in self.meters.items():
-            loss_str.append(
-                "{}: {}".format(name, str(meter))
-            )
-        return self.delimiter.join(loss_str)
-
-    def add_meter(self, name, meter):
-        self.meters[name] = meter
-
-    def log_every(self, iterable, print_freq, header=None):
-        i = 0
-        if not header:
-            header = ''
-        start_time = time.time()
-        end = time.time()
-        iter_time = SmoothedValue(fmt='{avg:.4f}')
-        data_time = SmoothedValue(fmt='{avg:.4f}')
-        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
-        log_msg = [
-            header,
-            '[{0' + space_fmt + '}/{1}]',
-            'eta: {eta}',
-            '{meters}',
-            'time: {time}',
-            'data: {data}'
-        ]
-        if torch.cuda.is_available():
-            log_msg.append('max mem: {memory:.0f}')
-        log_msg = self.delimiter.join(log_msg)
-        MB = 1024.0 * 1024.0
-        for obj in iterable:
-            data_time.update(time.time() - end)
-            yield obj
-            iter_time.update(time.time() - end)
-            if i % print_freq == 0 or i == len(iterable) - 1:
-                eta_seconds = iter_time.global_avg * (len(iterable) - i)
-                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
-                if torch.cuda.is_available():
-                    print(log_msg.format(
-                        i, len(iterable), eta=eta_string,
-                        meters=str(self),
-                        time=str(iter_time), data=str(data_time),
-                        memory=torch.cuda.max_memory_allocated() / MB))
-                else:
-                    print(log_msg.format(
-                        i, len(iterable), eta=eta_string,
-                        meters=str(self),
-                        time=str(iter_time), data=str(data_time)))
-            i += 1
-            end = time.time()
-        total_time = time.time() - start_time
-        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-        print('{} Total time: {} ({:.4f} s / it)'.format(
-            header, total_time_str, total_time / len(iterable)))
-
-
-# ---------------------- Model functions ----------------------
-def load_model(args, model, optimizer, lr_scheduler):
-    if args.resume and args.resume.lower() != 'none':
-        print("=================== Load checkpoint ===================")
-        if args.resume.startswith('https'):
-            checkpoint = torch.hub.load_state_dict_from_url(
-                args.resume, map_location='cpu', check_hash=True)
-        else:
-            checkpoint = torch.load(args.resume, map_location='cpu')
-        model.load_state_dict(checkpoint['model'])
-        print("Resume checkpoint %s" % args.resume)
-        
-        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
-            print('- Load optimizer from the checkpoint. ')
-            optimizer.load_state_dict(checkpoint['optimizer'])
-            args.start_epoch = checkpoint['epoch'] + 1
-
-        if 'lr_scheduler' in checkpoint:
-            print('- Load lr scheduler from the checkpoint. ')
-            lr_scheduler.load_state_dict(checkpoint.pop("lr_scheduler"))
-
-def save_model(args, epoch, model, optimizer, lr_scheduler, acc1=None, mae_task=False):
-    output_dir = Path(args.output_dir)
-    epoch_name = str(epoch)
-    if acc1 is not None:
-        checkpoint_paths = [output_dir / ('checkpoint-{}-Acc1-{:.2f}.pth'.format(epoch_name, acc1))]
-    else:
-        checkpoint_paths = [output_dir / ('checkpoint-{}.pth'.format(epoch_name))]
-    for checkpoint_path in checkpoint_paths:
-        to_save = {
-            'model': model.state_dict(),
-            'optimizer': optimizer.state_dict(),
-            'lr_scheduler': lr_scheduler.state_dict(),
-            'epoch': epoch,
-            'args': args,
-        }
-        if mae_task:
-            to_save['encoder'] = model.mae_encoder.state_dict()
-        torch.save(to_save, checkpoint_path)
-
-
-# ---------------------- Patch operations ----------------------
-def patchify(imgs, patch_size):
-    """
-    imgs: (B, 3, H, W)
-    x: (B, N, patch_size**2 *3)
-    """
-    p = patch_size
-    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
-
-    h = w = imgs.shape[2] // p
-    x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
-    x = torch.einsum('nchpwq->nhwpqc', x)
-    x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
-
-    return x
-
-def unpatchify(x, patch_size):
-    """
-    x: (B, N, patch_size**2 *3)
-    imgs: (B, 3, H, W)
-    """
-    p = patch_size
-    h = w = int(x.shape[1]**.5)
-    assert h * w == x.shape[1]
-    
-    x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
-    x = torch.einsum('nhwpqc->nchpwq', x)
-    imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
-
-    return imgs

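A quick round-trip check for the patch helpers above: patchify followed by unpatchify reproduces the input exactly for square images whose side is divisible by the patch size.

import torch

imgs = torch.randn(2, 3, 224, 224)
patches = patchify(imgs, patch_size=16)      # [2, 196, 768]  (768 = 16*16*3)
recon = unpatchify(patches, patch_size=16)   # [2, 3, 224, 224]
print(torch.allclose(imgs, recon))           # True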
+ 0 - 25
masked_image_modeling/utils/optimizer.py

@@ -1,25 +0,0 @@
-import torch
-
-
-def build_optimizer(args, model):
-    ## learning rate
-    if args.optimizer == "adamw":
-        args.base_lr = args.base_lr / 256 * args.batch_size
-        optimizer = torch.optim.AdamW(model.parameters(),
-                                      lr=args.base_lr,
-                                      weight_decay=args.weight_decay)
-    elif args.optimizer == "sgd":
-        args.base_lr = args.base_lr / 256 * args.batch_size
-        optimizer = torch.optim.SGD(model.parameters(),
-                                    lr=args.base_lr,
-                                    momentum=0.9,
-                                    weight_decay=args.weight_decay)
-    else:
-        raise NotImplementedError("Unknown optimizer: {}".format(args.optimizer))
-
-    print("=================== Optimizer information ===================")
-    print("Optimizer: ", args.optimizer)
-    print('- base lr: ', args.base_lr)
-    print('- min  lr: ', args.min_lr)
-
-    return optimizer

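Worked example of the linear scaling rule applied above (effective lr = base_lr / 256 * batch_size), using illustrative values:

base_lr, batch_size = 1.5e-4, 1024
effective_lr = base_lr / 256 * batch_size
print(effective_lr)   # 6e-4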
+ 1 - 0
yolo/.gitignore

@@ -8,3 +8,4 @@ weights
 __pycache__
 det_results
 .vscode
+odlab/

+ 0 - 110
yolo/benchmark.py

@@ -1,110 +0,0 @@
-import argparse
-import time
-import torch
-
-# load transform
-from dataset.build import build_dataset, build_transform
-
-# load some utils
-from utils.misc import load_weight, compute_flops
-from config import build_config
-from models import build_model
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
-    # Basic setting
-    parser.add_argument('-size', '--img_size', default=640, type=int,
-                        help='the max size of input image')
-    parser.add_argument('--cuda', action='store_true', default=False, 
-                        help='use cuda.')
-
-    # Model setting
-    parser.add_argument('-m', '--model', default='yolov1_r18', type=str,
-                        help='build yolo')
-    parser.add_argument('--weight', default=None,
-                        type=str, help='Trained state_dict file path to open')
-    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
-                        help='fuse Conv & BN')
-
-    # Data setting
-    parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
-                        help='data root')
-
-    return parser.parse_args()
-
-
-@torch.no_grad()
-def test_det(model, 
-             device, 
-             dataset,
-             transform=None
-             ):
-    # Step-1: Compute FLOPs and Params
-    compute_flops(model, cfg.test_img_size, device)
-
-    # Step-2: Compute FPS
-    num_images = 2002
-    total_time = 0
-    count = 0
-    with torch.no_grad():
-        for index in range(num_images):
-            if index % 500 == 0:
-                print('Testing image {:d}/{:d}....'.format(index+1, num_images))
-
-            # Load an image
-            image, _ = dataset.pull_image(index)
-
-            # Preprocess
-            x, _, ratio = transform(image)
-            x = x.unsqueeze(0).to(device)
-
-            # Start
-            torch.cuda.synchronize()
-            start_time = time.perf_counter()   
-
-            # Inference
-            outputs = model(x)
-
-            # End
-            torch.cuda.synchronize()
-            elapsed = time.perf_counter() - start_time
-        
-            if index > 1:
-                total_time += elapsed
-                count += 1
-
-        print('- FPS :', 1.0 / (total_time / count))
-
-if __name__ == '__main__':
-    args = parse_args()
-    # cuda
-    if args.cuda:
-        print('use cuda')
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
-
-    # Model Config
-    cfg = build_config(args)
-
-    # Transform
-    transform = build_transform(cfg, is_train=False)
-
-    # Dataset
-    args.dataset = 'coco'
-    dataset = build_dataset(args, cfg, transform, is_train=False)
-
-    # Build model
-    model = build_model(args, cfg, is_val=False)
-
-    # Load trained weight
-    model = load_weight(model, args.weight, args.fuse_conv_bn, rep_conv=True)
-    model.to(device).eval()
-        
-    # Run
-    test_det(model     = model, 
-             device    = device, 
-             dataset   = dataset,
-             transform = transform,
-             )

+ 1 - 1
yolo/config/__init__.py

@@ -6,7 +6,7 @@ from .yolov5_config     import build_yolov5_config
 from .yolov5_af_config  import build_yolov5af_config
 from .yolov6_config     import build_yolov6_config
 from .yolov8_config     import build_yolov8_config
-from .gelan_config      import build_gelan_config
+from .yolov9_config     import build_gelan_config
 from .rtdetr_config     import build_rtdetr_config
 
 
 
 

+ 0 - 0
masked_image_modeling/utils/__init__.py → yolo/config/detr_config.py


+ 0 - 0
yolo/models/gelan/README.md → yolo/config/fcos_config.py


+ 0 - 0
yolo/tools/__init__.py → yolo/config/yolof_config.py


+ 0 - 0
yolo/config/yolov10_config.py


+ 0 - 0
yolo/config/yolov11_config.py


+ 0 - 0
yolo/config/yolov4_config.py


+ 0 - 0
yolo/config/yolov7_config.py


+ 0 - 2
yolo/config/gelan_config.py → yolo/config/yolov9_config.py

@@ -106,8 +106,6 @@ class GElanBaseConfig(object):
 
 
         # ---------------- Data process config ----------------
         self.aug_type = 'yolo'
-        self.box_format = 'xyxy'
-        self.normalize_coords = False
         self.mosaic_prob = 0.0
         self.mixup_prob  = 0.0
         self.copy_paste  = 0.0           # approximated by the YOLOX's mixup

+ 1 - 1
yolo/models/__init__.py

@@ -9,7 +9,7 @@ from .yolov5.build     import build_yolov5
 from .yolov5_af.build  import build_yolov5af
 from .yolov6.build     import build_yolov6
 from .yolov8.build     import build_yolov8
-from .gelan.build      import build_gelan
+from .yolov9.build     import build_gelan
 from .rtdetr.build     import build_rtdetr
 
 
 
 

+ 0 - 0
yolo/models/detr/build.py


+ 0 - 0
yolo/models/detr/detr.py


+ 0 - 0
yolo/models/detr/detr_backbone.py


+ 0 - 0
yolo/models/detr/detr_transformer.py


+ 0 - 0
yolo/models/detr/loss.py


+ 0 - 0
yolo/models/detr/matcher.py


+ 148 - 0
yolo/models/detr/modules.py

@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        use_bias = False if norm_type is not None else True
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1, bias=use_bias)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# --------------------- ResNet modules ---------------------
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out

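A hypothetical shape check for the BasicConv added above, contrasting the standard and depthwise-separable paths (both keep the same output geometry):

import torch

x = torch.randn(2, 64, 32, 32)
conv_std = BasicConv(64, 128, kernel_size=3, padding=1, stride=2, act_type='silu', norm_type='BN')
conv_dw  = BasicConv(64, 128, kernel_size=3, padding=1, stride=2, act_type='silu', norm_type='BN',
                     depthwise=True)
print(conv_std(x).shape, conv_dw(x).shape)   # both torch.Size([2, 128, 16, 16])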
+ 0 - 0
yolo/models/fcos/build.py


+ 0 - 0
yolo/models/fcos/fcos.py


+ 0 - 0
yolo/models/fcos/fcos_backbone.py


+ 68 - 0
yolo/models/fcos/fcos_fpn.py

@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# ------------------ Basic Feature Pyramid Network ------------------
+class BasicFPN(nn.Module):
+    def __init__(self, cfg, 
+                 in_dims=[512, 1024, 2048],
+                 out_dim=256,
+                 ):
+        super().__init__()
+        # ------------------ Basic parameters -------------------
+        self.p6_feat = cfg.fpn_p6_feat
+        self.p7_feat = cfg.fpn_p7_feat
+        self.from_c5 = cfg.fpn_p6_from_c5
+
+        # ------------------ Network parameters -------------------
+        ## latter layers
+        self.input_projs = nn.ModuleList()
+        self.smooth_layers = nn.ModuleList()
+        for in_dim in in_dims[::-1]:
+            self.input_projs.append(nn.Conv2d(in_dim, out_dim, kernel_size=1))
+            self.smooth_layers.append(nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1))
+
+        ## P6/P7 layers
+        if self.p6_feat:
+            if self.from_c5:
+                self.p6_conv = nn.Conv2d(in_dims[-1], out_dim, kernel_size=3, stride=2, padding=1)
+            else: # from p5
+                self.p6_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
+        if self.p7_feat:
+            self.p7_conv = nn.Sequential(
+                nn.ReLU(inplace=True),
+                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
+            )
+
+    def forward(self, feats):
+        """
+            feats: (List of Tensor) [C3, C4, C5], C_i ∈ R^(B x C_i x H_i x W_i)
+        """
+        outputs = []
+        # [C3, C4, C5] -> [C5, C4, C3]
+        feats = feats[::-1]
+        top_level_feat = feats[0]
+        prev_feat = self.input_projs[0](top_level_feat)
+        outputs.append(self.smooth_layers[0](prev_feat))
+
+        for feat, input_proj, smooth_layer in zip(feats[1:], self.input_projs[1:], self.smooth_layers[1:]):
+            feat = input_proj(feat)
+            top_down_feat = F.interpolate(prev_feat, size=feat.shape[2:], mode='nearest')
+            prev_feat = feat + top_down_feat
+            outputs.insert(0, smooth_layer(prev_feat))
+
+        if self.p6_feat:
+            if self.from_c5:
+                p6_feat = self.p6_conv(feats[0])
+            else:
+                p6_feat = self.p6_conv(outputs[-1])
+            # [P3, P4, P5] -> [P3, P4, P5, P6]
+            outputs.append(p6_feat)
+
+            if self.p7_feat:
+                p7_feat = self.p7_conv(p6_feat)
+                # [P3, P4, P5, P6] -> [P3, P4, P5, P6, P7]
+                outputs.append(p7_feat)
+
+        # [P3, P4, P5] or [P3, P4, P5, P6, P7]
+        return outputs

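A hedged usage sketch for the FPN above with a minimal stand-in config; the real cfg comes from yolo/config/fcos_config.py and its field values may differ:

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(fpn_p6_feat=True, fpn_p7_feat=True, fpn_p6_from_c5=False)
fpn = BasicFPN(cfg, in_dims=[512, 1024, 2048], out_dim=256)

c3 = torch.randn(1, 512,  80, 80)
c4 = torch.randn(1, 1024, 40, 40)
c5 = torch.randn(1, 2048, 20, 20)
pyramid = fpn([c3, c4, c5])
print([tuple(p.shape[-2:]) for p in pyramid])   # P3..P7: (80, 80), (40, 40), (20, 20), (10, 10), (5, 5)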
+ 186 - 0
yolo/models/fcos/fcos_head.py

@@ -0,0 +1,186 @@
+import torch
+import torch.nn as nn
+
+from .modules import BasicConv
+
+
+class Scale(nn.Module):
+    """
+    Multiply the output regression range by a learnable constant value
+    """
+    def __init__(self, init_value=1.0):
+        """
+        init_value : initial value for the scalar
+        """
+        super().__init__()
+        self.scale = nn.Parameter(
+            torch.tensor(init_value, dtype=torch.float32),
+            requires_grad=True
+        )
+
+    def forward(self, x):
+        """
+        input -> scale * input
+        """
+        return x * self.scale
+
+class FcosHead(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim,):
+        super().__init__()
+        self.fmp_size = None
+        # ------------------ Basic parameters -------------------
+        self.cfg = cfg
+        self.in_dim = in_dim
+        self.stride       = cfg.out_stride
+        self.num_classes  = cfg.num_classes
+        self.num_cls_head = cfg.num_cls_head
+        self.num_reg_head = cfg.num_reg_head
+        self.act_type     = cfg.head_act
+        self.norm_type    = cfg.head_norm
+
+        # ------------------ Network parameters -------------------
+        ## cls head
+        cls_heads = []
+        self.cls_head_dim = out_dim
+        for i in range(self.num_cls_head):
+            if i == 0:
+                cls_heads.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+            else:
+                cls_heads.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+        
+        ## reg head
+        reg_heads = []
+        self.reg_head_dim = out_dim
+        for i in range(self.num_reg_head):
+            if i == 0:
+                reg_heads.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+            else:
+                reg_heads.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+        self.cls_heads = nn.Sequential(*cls_heads)
+        self.reg_heads = nn.Sequential(*reg_heads)
+
+        ## pred layers
+        self.cls_pred = nn.Conv2d(self.cls_head_dim, cfg.num_classes, kernel_size=3, padding=1)
+        self.reg_pred = nn.Conv2d(self.reg_head_dim, 4, kernel_size=3, padding=1)
+        self.ctn_pred = nn.Conv2d(self.reg_head_dim, 1, kernel_size=3, padding=1)
+        
+        ## scale layers
+        self.scales = nn.ModuleList(
+            Scale() for _ in range(len(self.stride))
+        )
+        
+        # init bias
+        self._init_layers()
+
+    def _init_layers(self):
+        for module in [self.cls_heads, self.reg_heads, self.cls_pred, self.reg_pred, self.ctn_pred]:
+            for layer in module.modules():
+                if isinstance(layer, nn.Conv2d):
+                    torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
+                    if layer.bias is not None:
+                        torch.nn.init.constant_(layer.bias, 0)
+                if isinstance(layer, nn.GroupNorm):
+                    torch.nn.init.constant_(layer.weight, 1)
+                    if layer.bias is not None:
+                        torch.nn.init.constant_(layer.bias, 0)
+        # init the bias of cls pred
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        torch.nn.init.constant_(self.cls_pred.bias, bias_value)
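+        # e.g. init_prob = 0.01 gives bias ≈ -4.59, so the initial sigmoid scores are ≈ 0.01;
+        # this is the usual focal-loss prior that keeps the classification loss stable early on.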
+        
+    def get_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
+        anchors *= self.stride[level]
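+        # e.g. with stride 8 the first anchor points are (4, 4), (12, 4), (20, 4), ...:
+        # each grid cell is mapped to the center of its cell on the input image.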
+
+        return anchors
+        
+    def decode_boxes(self, pred_deltas, anchors):
+        """
+            pred_deltas: (Tensor) [B, M, 4] or [M, 4], (l, t, r, b)
+            anchors:     (Tensor) [1, M, 2] or [M, 2]
+        """
+        # x1 = x_anchor - l, x2 = x_anchor + r
+        # y1 = y_anchor - t, y2 = y_anchor + b
+        pred_x1y1 = anchors - pred_deltas[..., :2]
+        pred_x2y2 = anchors + pred_deltas[..., 2:]
+        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return pred_box
+    
+    def forward(self, pyramid_feats, mask=None):
+        all_masks = []
+        all_anchors = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        all_ctn_preds = []
+        for level, feat in enumerate(pyramid_feats):
+            # ------------------- Decoupled head -------------------
+            cls_feat = self.cls_heads(feat)
+            reg_feat = self.reg_heads(feat)
+
+            # ------------------- Generate anchor box -------------------
+            B, _, H, W = cls_feat.size()
+            fmp_size = [H, W]
+            anchors = self.get_anchors(level, fmp_size)   # [M, 2]
+            anchors = anchors.to(cls_feat.device)
+
+            # ------------------- Predict -------------------
+            cls_pred = self.cls_pred(cls_feat)
+            reg_pred = self.reg_pred(reg_feat)
+            ctn_pred = self.ctn_pred(reg_feat)
+
+            # ------------------- Process preds -------------------
+            ## [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+            ctn_pred = ctn_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
+            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+            reg_pred = nn.functional.relu(self.scales[level](reg_pred)) * self.stride[level]
+            ## Decode bbox
+            box_pred = self.decode_boxes(reg_pred, anchors)
+            ## Adjust mask
+            if mask is not None:
+                # [B, H, W]
+                mask_i = torch.nn.functional.interpolate(mask[None].float(), size=[H, W]).bool()[0]
+                # [B, H, W] -> [B, M]
+                mask_i = mask_i.flatten(1)     
+                all_masks.append(mask_i)
+                
+            all_anchors.append(anchors)
+            all_cls_preds.append(cls_pred)
+            all_reg_preds.append(reg_pred)
+            all_box_preds.append(box_pred)
+            all_ctn_preds.append(ctn_pred)
+
+        outputs = {"pred_cls": all_cls_preds,  # List [B, M, C]
+                   "pred_reg": all_reg_preds,  # List [B, M, 4]
+                   "pred_box": all_box_preds,  # List [B, M, 4]
+                   "pred_ctn": all_ctn_preds,  # List [B, M, 1]
+                   "anchors": all_anchors,     # List [M, 2]
+                   "strides": self.stride,
+                   "mask": all_masks}          # List [B, M,]
+
+        return outputs 
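A quick numeric check of the ltrb decoding performed by decode_boxes above; the values are made up for illustration:

import torch

anchor = torch.tensor([[100.0, 60.0]])             # one anchor point (x, y)
deltas = torch.tensor([[20.0, 10.0, 30.0, 50.0]])  # predicted (l, t, r, b) in pixels

x1y1 = anchor - deltas[..., :2]   # (100 - 20, 60 - 10) = (80, 50)
x2y2 = anchor + deltas[..., 2:]   # (100 + 30, 60 + 50) = (130, 110)
print(torch.cat([x1y1, x2y2], dim=-1))  # tensor([[ 80.,  50., 130., 110.]])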

+ 290 - 0
yolo/models/fcos/loss.py

@@ -0,0 +1,290 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils.box_ops import get_ious
+from utils.misc import sigmoid_focal_loss
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import FcosMatcher, AlignedOTAMatcher
+
+
+class SetCriterion(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        # ------------- Basic parameters -------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        # ------------- Focal loss -------------
+        self.alpha = cfg.focal_loss_alpha
+        self.gamma = cfg.focal_loss_gamma
+        # ------------- Matcher & Loss weight -------------
+        self.matcher_cfg = cfg.matcher_hpy
+        if cfg.matcher == 'fcos_matcher':
+            self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
+                                'loss_reg': cfg.loss_reg_weight,
+                                'loss_ctn': cfg.loss_ctn_weight}
+            self.matcher = FcosMatcher(cfg.num_classes,
+                                       self.matcher_cfg['center_sampling_radius'],
+                                       self.matcher_cfg['object_sizes_of_interest'],
+                                       [1., 1., 1., 1.]
+                                       )
+        elif cfg.matcher == 'simota':
+            self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
+                                'loss_reg': cfg.loss_reg_weight}
+            self.matcher = AlignedOTAMatcher(cfg.num_classes,
+                                             self.matcher_cfg['soft_center_radius'],
+                                             self.matcher_cfg['topk_candidates'])
+        else:
+            raise NotImplementedError("Unknown matcher: {}.".format(cfg.matcher))
+
+    def loss_labels(self, pred_cls, tgt_cls, num_boxes=1.0):
+        """
+            pred_cls: (Tensor) [N, C]
+            tgt_cls:  (Tensor) [N, C]
+        """
+        # cls loss: [V, C]
+        loss_cls = sigmoid_focal_loss(pred_cls, tgt_cls, self.alpha, self.gamma)
+
+        return loss_cls.sum() / num_boxes
+
+    def loss_labels_qfl(self, pred_cls, target, beta=2.0, num_boxes=1.0):
+        # Quality FocalLoss
+        """
+            pred_cls: (torch.Tensor) [N, C]
+            target:   (tuple of torch.Tensor) label -> (N,), score -> (N,)
+        """
+        label, score = target
+        pred_sigmoid = pred_cls.sigmoid()
+        scale_factor = pred_sigmoid
+        zerolabel = scale_factor.new_zeros(pred_cls.shape)
+
+        ce_loss = F.binary_cross_entropy_with_logits(
+            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
+        
+        bg_class_ind = pred_cls.shape[-1]
+        pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
+        if pos.shape[0] > 0:
+            pos_label = label[pos].long()
+
+            scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
+
+            ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
+                pred_cls[pos, pos_label], score[pos],
+                reduction='none') * scale_factor.abs().pow(beta)
+
+        return ce_loss.sum() / num_boxes
+    
+    def loss_bboxes_ltrb(self, pred_delta, tgt_delta, bbox_quality=None, num_boxes=1.0):
+        """
+            pred_delta: (Tensor) [N, 4], (l, t, r, b)
+            tgt_delta:  (Tensor) [N, 4], (l, t, r, b)
+        """
+        pred_delta = torch.cat((-pred_delta[..., :2], pred_delta[..., 2:]), dim=-1)
+        tgt_delta = torch.cat((-tgt_delta[..., :2], tgt_delta[..., 2:]), dim=-1)
+
+        eps = torch.finfo(torch.float32).eps
+
+        pred_area = (pred_delta[..., 2] - pred_delta[..., 0]).clamp_(min=0) \
+            * (pred_delta[..., 3] - pred_delta[..., 1]).clamp_(min=0)
+        tgt_area = (tgt_delta[..., 2] - tgt_delta[..., 0]).clamp_(min=0) \
+            * (tgt_delta[..., 3] - tgt_delta[..., 1]).clamp_(min=0)
+
+        w_intersect = (torch.min(pred_delta[..., 2], tgt_delta[..., 2])
+                    - torch.max(pred_delta[..., 0], tgt_delta[..., 0])).clamp_(min=0)
+        h_intersect = (torch.min(pred_delta[..., 3], tgt_delta[..., 3])
+                    - torch.max(pred_delta[..., 1], tgt_delta[..., 1])).clamp_(min=0)
+
+        area_intersect = w_intersect * h_intersect
+        area_union = tgt_area + pred_area - area_intersect
+        ious = area_intersect / area_union.clamp(min=eps)
+
+        # giou
+        g_w_intersect = torch.max(pred_delta[..., 2], tgt_delta[..., 2]) \
+            - torch.min(pred_delta[..., 0], tgt_delta[..., 0])
+        g_h_intersect = torch.max(pred_delta[..., 3], tgt_delta[..., 3]) \
+            - torch.min(pred_delta[..., 1], tgt_delta[..., 1])
+        ac_union = g_w_intersect * g_h_intersect
+        gious = ious - (ac_union - area_union) / ac_union.clamp(min=eps)
+        loss_box = 1 - gious
+
+        if bbox_quality is not None:
+            loss_box = loss_box * bbox_quality.view(loss_box.size())
+
+        return loss_box.sum() / num_boxes
+
+    def loss_bboxes_xyxy(self, pred_box, gt_box, num_boxes=1.0, box_weight=None):
+        ious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou')
+        loss_box = 1.0 - ious
+
+        if box_weight is not None:
+            loss_box = loss_box.squeeze(-1) * box_weight
+
+        return loss_box.sum() / num_boxes
+    
+    def fcos_loss(self, outputs, targets):
+        """
+            outputs['pred_cls']: (Tensor) [B, M, C]
+            outputs['pred_reg']: (Tensor) [B, M, 4]
+            outputs['pred_ctn']: (Tensor) [B, M, 1]
+            outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        # -------------------- Pre-process --------------------
+        device = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        anchors = outputs['anchors']
+        pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)
+        pred_delta = torch.cat(outputs['pred_reg'], dim=1).view(-1, 4)
+        pred_ctn = torch.cat(outputs['pred_ctn'], dim=1).view(-1, 1)
+        masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
+
+        # -------------------- Label Assignment --------------------
+        gt_classes, gt_deltas, gt_centerness = self.matcher(fpn_strides, anchors, targets)
+        gt_classes = gt_classes.flatten().to(device)
+        gt_deltas = gt_deltas.view(-1, 4).to(device)
+        gt_centerness = gt_centerness.view(-1, 1).to(device)
+
+        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
+        num_foreground = foreground_idxs.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_foreground)
+        num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()
+
+        num_foreground_centerness = gt_centerness[foreground_idxs].sum()
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_foreground_centerness)
+        num_targets = torch.clamp(num_foreground_centerness / get_world_size(), min=1).item()
+
+        # -------------------- classification loss --------------------
+        gt_classes_target = torch.zeros_like(pred_cls)
+        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1
+        valid_idxs = (gt_classes >= 0) & masks
+        loss_labels = self.loss_labels(
+            pred_cls[valid_idxs], gt_classes_target[valid_idxs], num_foreground)
+
+        # -------------------- regression loss --------------------
+        loss_bboxes = self.loss_bboxes_ltrb(
+            pred_delta[foreground_idxs], gt_deltas[foreground_idxs], gt_centerness[foreground_idxs], num_targets)
+
+        # -------------------- centerness loss --------------------
+        loss_centerness = F.binary_cross_entropy_with_logits(
+            pred_ctn[foreground_idxs],  gt_centerness[foreground_idxs], reduction='none')
+        loss_centerness = loss_centerness.sum() / num_foreground
+
+        total_loss = loss_labels * self.weight_dict["loss_cls"] + \
+                     loss_bboxes * self.weight_dict["loss_reg"] + \
+                     loss_centerness * self.weight_dict["loss_ctn"]
+        loss_dict = dict(
+                loss_cls = loss_labels,
+                loss_reg = loss_bboxes,
+                loss_ctn = loss_centerness,
+                losses   = total_loss,
+        )
+
+        return loss_dict
+    
+    def ota_loss(self, outputs, targets):
+        """
+            outputs['pred_cls']: (Tensor) [B, M, C]
+            outputs['pred_reg']: (Tensor) [B, M, 4]
+            outputs['pred_box']: (Tensor) [B, M, 4]
+            outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        # -------------------- Pre-process --------------------
+        bs          = outputs['pred_cls'][0].shape[0]
+        device      = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        anchors     = outputs['anchors']
+        # preds: [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+        masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
+
+        # -------------------- Label Assignment --------------------
+        cls_targets = []
+        box_targets = []
+        assign_metrics = []
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
+            # refine target
+            tgt_boxes_wh = tgt_bboxes[..., 2:] - tgt_bboxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= 8)
+            tgt_bboxes = tgt_bboxes[keep]
+            tgt_labels = tgt_labels[keep]
+            # label assignment
+            assigned_result = self.matcher(fpn_strides=fpn_strides,
+                                           anchors=anchors,
+                                           pred_cls=cls_preds[batch_idx].detach(),
+                                           pred_box=box_preds[batch_idx].detach(),
+                                           gt_labels=tgt_labels,
+                                           gt_bboxes=tgt_bboxes
+                                           )
+            cls_targets.append(assigned_result['assigned_labels'])
+            box_targets.append(assigned_result['assigned_bboxes'])
+            assign_metrics.append(assigned_result['assign_metrics'])
+
+        # List[B, M, C] -> Tensor[BM, C]
+        cls_targets = torch.cat(cls_targets, dim=0)
+        box_targets = torch.cat(box_targets, dim=0)
+        assign_metrics = torch.cat(assign_metrics, dim=0)
+
+        valid_idxs = (cls_targets >= 0) & masks
+        foreground_idxs = (cls_targets >= 0) & (cls_targets != self.num_classes)
+        num_fgs = assign_metrics.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
+
+        # -------------------- classification loss --------------------
+        cls_preds = cls_preds.view(-1, self.num_classes)[valid_idxs]
+        qfl_targets = (cls_targets[valid_idxs], assign_metrics[valid_idxs])
+        loss_labels = self.loss_labels_qfl(cls_preds, qfl_targets, 2.0, num_fgs)
+
+        # -------------------- regression loss --------------------
+        box_preds_pos = box_preds.view(-1, 4)[foreground_idxs]
+        box_targets_pos = box_targets[foreground_idxs]
+        box_weight = assign_metrics[foreground_idxs]
+        loss_bboxes = self.loss_bboxes_xyxy(box_preds_pos, box_targets_pos, num_fgs, box_weight)
+
+        total_loss = loss_labels * self.weight_dict["loss_cls"] + \
+                     loss_bboxes * self.weight_dict["loss_reg"]
+        loss_dict = dict(
+                loss_cls = loss_labels,
+                loss_reg = loss_bboxes,
+                losses   = total_loss,
+        )
+
+        return loss_dict
+    
+    def forward(self, outputs, targets):
+        """
+            outputs['pred_cls']: (Tensor) [B, M, C]
+            outputs['pred_reg']: (Tensor) [B, M, 4]
+            outputs['pred_ctn']: (Tensor) [B, M, 1]
+            outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        if self.cfg.matcher == "fcos_matcher":
+            return self.fcos_loss(outputs, targets)
+        elif self.cfg.matcher == "simota":
+            return self.ota_loss(outputs, targets)
+        else:
+            raise NotImplementedError
+
+
+if __name__ == "__main__":
+    pass
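The criterion above relies on sigmoid_focal_loss imported from utils.misc, which is not part of this diff. For reference, a minimal sketch of the usual element-wise sigmoid focal loss (RetinaNet-style) that loss_labels assumes; the actual helper in utils.misc may differ in details:

import torch
import torch.nn.functional as F

def sigmoid_focal_loss_sketch(logits, targets, alpha=0.25, gamma=2.0):
    """Element-wise sigmoid focal loss; returns a tensor shaped like `logits`."""
    p  = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = p * targets + (1.0 - p) * (1.0 - targets)   # probability assigned to the true label
    loss = ce * (1.0 - p_t) ** gamma                  # down-weight easy examples
    if alpha >= 0:
        loss = (alpha * targets + (1.0 - alpha) * (1.0 - targets)) * loss
    return loss

# usage mirroring loss_labels: sum over all valid locations, normalize by num_foreground
# loss_cls = sigmoid_focal_loss_sketch(pred_cls, gt_classes_target).sum() / num_foreground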

+ 378 - 0
yolo/models/fcos/matcher.py

@@ -0,0 +1,378 @@
+import math
+import torch
+import torch.nn.functional as F
+
+from utils.box_ops import *
+
+
+@torch.no_grad()
+def get_ious_and_iou_loss(inputs,
+                          targets,
+                          weight=None,
+                          box_mode="xyxy",
+                          loss_type="iou",
+                          reduction="none"):
+    """
+    Compute iou loss of type ['iou', 'giou', 'linear_iou']
+
+    Args:
+        inputs (tensor): pred values
+        targets (tensor): target values
+        weight (tensor): loss weight
+        box_mode (str): 'xyxy' or 'ltrb'; both modes are supported.
+        loss_type (str): 'giou' or 'iou' or 'linear_iou'
+        reduction (str): reduction manner
+
+    Returns:
+        loss (tensor): computed iou loss.
+    """
+    if box_mode == "ltrb":
+        inputs = torch.cat((-inputs[..., :2], inputs[..., 2:]), dim=-1)
+        targets = torch.cat((-targets[..., :2], targets[..., 2:]), dim=-1)
+    elif box_mode != "xyxy":
+        raise NotImplementedError
+
+    eps = torch.finfo(torch.float32).eps
+
+    inputs_area = (inputs[..., 2] - inputs[..., 0]).clamp_(min=0) \
+        * (inputs[..., 3] - inputs[..., 1]).clamp_(min=0)
+    targets_area = (targets[..., 2] - targets[..., 0]).clamp_(min=0) \
+        * (targets[..., 3] - targets[..., 1]).clamp_(min=0)
+
+    w_intersect = (torch.min(inputs[..., 2], targets[..., 2])
+                   - torch.max(inputs[..., 0], targets[..., 0])).clamp_(min=0)
+    h_intersect = (torch.min(inputs[..., 3], targets[..., 3])
+                   - torch.max(inputs[..., 1], targets[..., 1])).clamp_(min=0)
+
+    area_intersect = w_intersect * h_intersect
+    area_union = targets_area + inputs_area - area_intersect
+    ious = area_intersect / area_union.clamp(min=eps)
+
+    if loss_type == "iou":
+        loss = -ious.clamp(min=eps).log()
+    elif loss_type == "linear_iou":
+        loss = 1 - ious
+    elif loss_type == "giou":
+        g_w_intersect = torch.max(inputs[..., 2], targets[..., 2]) \
+            - torch.min(inputs[..., 0], targets[..., 0])
+        g_h_intersect = torch.max(inputs[..., 3], targets[..., 3]) \
+            - torch.min(inputs[..., 1], targets[..., 1])
+        ac_union = g_w_intersect * g_h_intersect
+        gious = ious - (ac_union - area_union) / ac_union.clamp(min=eps)
+        loss = 1 - gious
+    else:
+        raise NotImplementedError
+    if weight is not None:
+        loss = loss * weight.view(loss.size())
+        if reduction == "mean":
+            loss = loss.sum() / max(weight.sum().item(), eps)
+    else:
+        if reduction == "mean":
+            loss = loss.mean()
+    if reduction == "sum":
+        loss = loss.sum()
+
+    return ious, loss
+
+
+class FcosMatcher(object):
+    """
+        This code references https://github.com/Megvii-BaseDetection/cvpods
+    """
+    def __init__(self, 
+                 num_classes,
+                 center_sampling_radius,
+                 object_sizes_of_interest,
+                 box_weights=[1, 1, 1, 1]):
+        self.num_classes = num_classes
+        self.center_sampling_radius = center_sampling_radius
+        self.object_sizes_of_interest = object_sizes_of_interest
+        self.box_weights = box_weights
+
+
+    def get_deltas(self, anchors, boxes):
+        """
+        Get box regression transformation deltas (dl, dt, dr, db) that can be used
+        to transform the `anchors` into the `boxes`. That is, the relation
+        ``boxes == self.apply_deltas(deltas, anchors)`` is true.
+
+        Args:
+            anchors (Tensor): anchors, e.g., feature map coordinates
+            boxes (Tensor): target of the transformation, e.g., ground-truth
+                boxes.
+        """
+        assert isinstance(anchors, torch.Tensor), type(anchors)
+        assert isinstance(boxes, torch.Tensor), type(boxes)
+        deltas = torch.cat((anchors - boxes[..., :2], boxes[..., 2:] - anchors),
+                           dim=-1) * anchors.new_tensor(self.box_weights)
+        return deltas
+
+
+    @torch.no_grad()
+    def __call__(self, fpn_strides, anchors, targets):
+        """
+            fpn_strides: (List) List[8, 16, 32, ...] stride of network output.
+            anchors: (List of Tensor) List[F, M, 2], F = num_fpn_levels
+            targets: (List) [dict{'boxes': [...], 
+                                  'labels': [...], 
+                                  'orig_size': ...}, ...]
+        """
+        gt_classes = []
+        gt_anchors_deltas = []
+        gt_centerness = []
+        device = anchors[0].device
+
+        # List[F, M, 2] -> [M, 2]
+        anchors_over_all_feature_maps = torch.cat(anchors, dim=0).to(device)
+
+        for targets_per_image in targets:
+            # generate object_sizes_of_interest: List[[M, 2]]
+            object_sizes_of_interest = [anchors_i.new_tensor(scale_range).unsqueeze(0).expand(anchors_i.size(0), -1) 
+                                        for anchors_i, scale_range in zip(anchors, self.object_sizes_of_interest)]
+            # List[F, M, 2] -> [M, 2], M = M1 + M2 + ... + MF
+            object_sizes_of_interest = torch.cat(object_sizes_of_interest, dim=0)
+            # [N, 4]
+            tgt_box = targets_per_image['boxes'].to(device)
+            # [N, C]
+            tgt_cls = targets_per_image['labels'].to(device)
+            # [N, M, 4], M = M1 + M2 + ... + MF
+            deltas = self.get_deltas(anchors_over_all_feature_maps, tgt_box.unsqueeze(1))
+
+            has_gt = (len(tgt_cls) > 0)
+            if has_gt:
+                if self.center_sampling_radius > 0:
+                    # bbox centers: [N, 2]
+                    centers = (tgt_box[..., :2] + tgt_box[..., 2:]) * 0.5
+
+                    is_in_boxes = []
+                    for stride, anchors_i in zip(fpn_strides, anchors):
+                        radius = stride * self.center_sampling_radius
+                        # [N, 4]
+                        center_boxes = torch.cat((
+                            torch.max(centers - radius, tgt_box[:, :2]),
+                            torch.min(centers + radius, tgt_box[:, 2:]),
+                        ), dim=-1)
+                        # [N, Mi, 4]
+                        center_deltas = self.get_deltas(anchors_i, center_boxes.unsqueeze(1))
+                        # [N, Mi]
+                        is_in_boxes.append(center_deltas.min(dim=-1).values > 0)
+                    # [N, M], M = M1 + M2 + ... + MF
+                    is_in_boxes = torch.cat(is_in_boxes, dim=1)
+                else:
+                    # no center sampling, it will use all the locations within a ground-truth box
+                    # [N, M], M = M1 + M2 + ... + MF
+                    is_in_boxes = deltas.min(dim=-1).values > 0
+                # [N, M], M = M1 + M2 + ... + MF
+                max_deltas = deltas.max(dim=-1).values
+                # limit the regression range for each location
+                is_cared_in_the_level = \
+                    (max_deltas >= object_sizes_of_interest[None, :, 0]) & \
+                    (max_deltas <= object_sizes_of_interest[None, :, 1])
+
+                # [N,]
+                tgt_box_area = (tgt_box[:, 2] - tgt_box[:, 0]) * (tgt_box[:, 3] - tgt_box[:, 1])
+                # [N,] -> [N, 1] -> [N, M]
+                gt_positions_area = tgt_box_area.unsqueeze(1).repeat(
+                    1, anchors_over_all_feature_maps.size(0))
+                gt_positions_area[~is_in_boxes] = math.inf
+                gt_positions_area[~is_cared_in_the_level] = math.inf
+
+                # if there are still more than one objects for a position,
+                # we choose the one with minimal area
+                # [M,], each element is the index of ground-truth
+                positions_min_area, gt_matched_idxs = gt_positions_area.min(dim=0)
+
+                # ground truth box regression
+                # [M, 4]
+                gt_anchors_reg_deltas_i = self.get_deltas(
+                    anchors_over_all_feature_maps, tgt_box[gt_matched_idxs])
+
+                # [M,]
+                tgt_cls_i = tgt_cls[gt_matched_idxs]
+                # anchors with area inf are treated as background.
+                tgt_cls_i[positions_min_area == math.inf] = self.num_classes
+
+                # ground truth centerness
+                left_right = gt_anchors_reg_deltas_i[:, [0, 2]]
+                top_bottom = gt_anchors_reg_deltas_i[:, [1, 3]]
+                # [M,]
+                gt_centerness_i = torch.sqrt(
+                    (left_right.min(dim=-1).values / left_right.max(dim=-1).values).clamp_(min=0)
+                    * (top_bottom.min(dim=-1).values / top_bottom.max(dim=-1).values).clamp_(min=0)
+                )
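+                # e.g. deltas (l, t, r, b) = (10, 40, 30, 40) give centerness
+                # sqrt((10/30) * (40/40)) ≈ 0.58; locations near the box center approach 1.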
+
+                gt_classes.append(tgt_cls_i)
+                gt_anchors_deltas.append(gt_anchors_reg_deltas_i)
+                gt_centerness.append(gt_centerness_i)
+
+                # note: center_boxes/center_deltas only exist when center sampling is enabled
+                del deltas, max_deltas
+
+            else:
+                tgt_cls_i = torch.zeros(anchors_over_all_feature_maps.shape[0], device=device) + self.num_classes
+                gt_anchors_reg_deltas_i = torch.zeros([anchors_over_all_feature_maps.shape[0], 4], device=device)
+                gt_centerness_i = torch.zeros(anchors_over_all_feature_maps.shape[0], device=device)
+
+                gt_classes.append(tgt_cls_i.long())
+                gt_anchors_deltas.append(gt_anchors_reg_deltas_i.float())
+                gt_centerness.append(gt_centerness_i.float())
+
+
+        # [B, M], [B, M, 4], [B, M]
+        return torch.stack(gt_classes), torch.stack(gt_anchors_deltas), torch.stack(gt_centerness)
+
+
+class AlignedOTAMatcher(object):
+    """
+    This code references https://github.com/open-mmlab/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py
+    """
+    def __init__(self, num_classes, soft_center_radius=3.0, topk_candidates=13):
+        self.num_classes = num_classes
+        self.soft_center_radius = soft_center_radius
+        self.topk_candidates = topk_candidates
+
+    @torch.no_grad()
+    def __call__(self, 
+                 fpn_strides, 
+                 anchors, 
+                 pred_cls, 
+                 pred_box,
+                 gt_labels,
+                 gt_bboxes):
+        # [M,]
+        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
+        # List[F, M, 2] -> [M, 2]
+        num_gt = len(gt_labels)
+        anchors = torch.cat(anchors, dim=0)
+
+        # check gt
+        if num_gt == 0 or gt_bboxes.max().item() == 0.:
+            return {
+                'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape,
+                                                      self.num_classes,
+                                                      dtype=torch.long),
+                'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
+                'assign_metrics': gt_bboxes.new_full(pred_cls[..., 0].shape, 0)
+            }
+        
+        # get inside points: [N, M]
+        is_in_gt = self.find_inside_points(gt_bboxes, anchors)
+        valid_mask = is_in_gt.sum(dim=0) > 0  # [M,]
+
+        # ----------------------------------- soft center prior -----------------------------------
+        gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
+        distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
+                    ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0)  # [N, M]
+        distance = distance * valid_mask.unsqueeze(0)
+        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
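+        # anchors within `soft_center_radius` strides of a GT center add less than 1 to the cost;
+        # every extra (stride-normalized) unit of distance multiplies the penalty by 10.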
+
+        # ----------------------------------- regression cost -----------------------------------
+        pair_wise_ious, _ = box_iou(gt_bboxes, pred_box)  # [N, M]
+        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * 3.0
+
+        # ----------------------------------- classification cost -----------------------------------
+        ## select the predicted scores corresponded to the gt_labels
+        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
+        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()   # [N, M]
+        ## scale factor
+        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
+        ## cls cost
+        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
+            pairwise_pred_scores, pair_wise_ious,
+            reduction="none") * scale_factor # [N, M]
+            
+        del pairwise_pred_scores
+
+        ## foreground cost matrix
+        cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
+        max_pad_value = torch.ones_like(cost_matrix) * 1e9
+        cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1),   # [N, M]
+                                  cost_matrix, max_pad_value)
+
+        # ----------------------------------- dynamic label assignment -----------------------------------
+        matched_pred_ious, matched_gt_inds, fg_mask_inboxes = self.dynamic_k_matching(
+            cost_matrix, pair_wise_ious, num_gt)
+        del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss
+
+        # -----------------------------------process assigned labels -----------------------------------
+        assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
+                                             self.num_classes)  # [M,]
+        assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
+        assigned_labels = assigned_labels.long()  # [M,]
+
+        assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0)        # [M, 4]
+        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds]  # [M, 4]
+
+        assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0) # [M,]
+        assign_metrics[fg_mask_inboxes] = matched_pred_ious            # [M,]
+
+        assigned_dict = dict(
+            assigned_labels=assigned_labels,
+            assigned_bboxes=assigned_bboxes,
+            assign_metrics=assign_metrics
+            )
+        
+        return assigned_dict
+
+    def find_inside_points(self, gt_bboxes, anchors):
+        """
+            gt_bboxes: Tensor -> [N, 4]
+            anchors:   Tensor -> [M, 2]
+        """
+        num_anchors = anchors.shape[0]
+        num_gt = gt_bboxes.shape[0]
+
+        anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1)           # [N, M, 2]
+        gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1)  # [N, M, 4]
+
+        # offset
+        lt = anchors_expand - gt_bboxes_expand[..., :2]
+        rb = gt_bboxes_expand[..., 2:] - anchors_expand
+        bbox_deltas = torch.cat([lt, rb], dim=-1)
+
+        is_in_gts = bbox_deltas.min(dim=-1).values > 0
+
+        return is_in_gts
+    
+    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
+        """Use IoU and matching cost to calculate the dynamic top-k positive
+        targets.
+
+        Args:
+            cost_matrix (Tensor): Cost matrix.
+            pairwise_ious (Tensor): Pairwise iou matrix.
+            num_gt (int): Number of gt.
+        Returns:
+            tuple: matched ious and gt indexes.
+        """
+        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
+        # select candidate topk ious for dynamic-k calculation
+        candidate_topk = min(self.topk_candidates, pairwise_ious.size(1))
+        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
+        # calculate dynamic k for each gt
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
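+        # e.g. if a GT's top candidate IoUs sum to 3.7, it gets k = 3 positives; well-covered
+        # objects receive more positives, poorly-covered ones still receive at least 1.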
+
+        # sorting the batch cost matrix is faster than topk
+        _, sorted_indices = torch.sort(cost_matrix, dim=1)
+        for gt_idx in range(num_gt):
+            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
+            matching_matrix[gt_idx, :][topk_ids] = 1
+
+        del topk_ious, dynamic_ks, topk_ids
+
+        prior_match_gt_mask = matching_matrix.sum(0) > 1
+        if prior_match_gt_mask.sum() > 0:
+            cost_min, cost_argmin = torch.min(
+                cost_matrix[:, prior_match_gt_mask], dim=0)
+            matching_matrix[:, prior_match_gt_mask] *= 0
+            matching_matrix[cost_argmin, prior_match_gt_mask] = 1
+
+        # get foreground mask inside box and center prior
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+        matched_pred_ious = (matching_matrix *
+                             pairwise_ious).sum(0)[fg_mask_inboxes]
+        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+
+        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes
+        
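A toy run of the dynamic-k selection above on a hand-made 2-GT x 6-anchor IoU matrix; the cost is taken as 1 - IoU purely for illustration (the real cost also includes the classification and soft-center-prior terms):

import torch

pairwise_ious = torch.tensor([[0.8, 0.6, 0.1, 0.0, 0.0, 0.0],
                              [0.0, 0.1, 0.2, 0.7, 0.9, 0.4]])
cost_matrix = 1.0 - pairwise_ious                          # illustrative cost only

matching = torch.zeros_like(cost_matrix, dtype=torch.uint8)
topk_ious, _ = torch.topk(pairwise_ious, min(13, pairwise_ious.size(1)), dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)    # tensor([1, 2])
_, sorted_idx = torch.sort(cost_matrix, dim=1)
for gt_idx in range(pairwise_ious.size(0)):
    matching[gt_idx, sorted_idx[gt_idx, :dynamic_ks[gt_idx]]] = 1

print(dynamic_ks.tolist())          # [1, 2]
print(matching.nonzero().tolist())  # [[0, 0], [1, 3], [1, 4]]: GT 0 -> anchor 0, GT 1 -> anchors 3 and 4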

+ 148 - 0
yolo/models/fcos/modules.py

@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        use_bias = False if norm_type is not None else True
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1, bias=use_bias)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# --------------------- ResNet modules ---------------------
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
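The depthwise=True branch of BasicConv above replaces one KxK convolution with a KxK depthwise conv followed by a 1x1 pointwise conv; a quick parameter count (3x3, 256 -> 256 channels, no bias) shows the saving:

import torch.nn as nn

standard  = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)
depthwise = nn.Sequential(
    nn.Conv2d(256, 256, kernel_size=3, padding=1, groups=256, bias=False),  # depthwise 3x3
    nn.Conv2d(256, 256, kernel_size=1, bias=False),                         # pointwise 1x1
)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(standard))   # 589824 = 256 * 256 * 3 * 3
print(count(depthwise))  # 67840  = 256 * 3 * 3 + 256 * 256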

+ 187 - 0
yolo/models/fcos/resnet.py

@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+
+try:
+    from .modules import conv1x1, BasicBlock, Bottleneck
+except:
+    from modules import conv1x1, BasicBlock, Bottleneck
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+# --------------------- ResNet -----------------------
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, zero_init_residual=False):
+        super(ResNet, self).__init__()
+        self.inplanes = 64
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        """
+        Input:
+            x: (Tensor) -> [B, C, H, W]
+        Output:
+            c5: (Tensor) -> [B, C, H/32, W/32]
+        """
+        c1 = self.conv1(x)     # [B, C, H/2, W/2]
+        c1 = self.bn1(c1)      # [B, C, H/2, W/2]
+        c1 = self.relu(c1)     # [B, C, H/2, W/2]
+        c2 = self.maxpool(c1)  # [B, C, H/4, W/4]
+
+        c2 = self.layer1(c2)   # [B, C, H/4, W/4]
+        c3 = self.layer2(c2)   # [B, C, H/8, W/8]
+        c4 = self.layer3(c3)   # [B, C, H/16, W/16]
+        c5 = self.layer4(c4)   # [B, C, H/32, W/32]
+
+        return c5
+
+
+# --------------------- Functions -----------------------
+def build_resnet(model_name="resnet18", pretrained=False):
+    if model_name == 'resnet18':
+        model = resnet18(pretrained)
+        feat_dim = 512
+    elif model_name == 'resnet34':
+        model = resnet34(pretrained)
+        feat_dim = 512
+    elif model_name == 'resnet50':
+        model = resnet50(pretrained)
+        feat_dim = 2048
+    elif model_name == 'resnet101':
+        model = resnet101(pretrained)
+        feat_dim = 2048
+    else:
+        raise NotImplementedError("Unknown resnet: {}".format(model_name))
+    
+    return model, feat_dim
+
+def resnet18(pretrained=False, **kwargs):
+    """Constructs a ResNet-18 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+    if pretrained:
+        # strict = False as we don't need fc layer params.
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
+    return model
+
+def resnet34(pretrained=False, **kwargs):
+    """Constructs a ResNet-34 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
+    return model
+
+def resnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
+    return model
+
+def resnet101(pretrained=False, **kwargs):
+    """Constructs a ResNet-101 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
+    return model
+
+def resnet152(pretrained=False, **kwargs):
+    """Constructs a ResNet-152 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
+    return model
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+
+    # Build backbone
+    model, _ = build_resnet(model_name='resnet18')
+
+    # Inference
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    output = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print(output.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))    

+ 0 - 24
yolo/models/gelan/build.py

@@ -1,24 +0,0 @@
-import torch.nn as nn
-
-from .loss import SetCriterion
-from .gelan import GElan
-
-
-# build object detector
-def build_gelan(cfg, is_val=False):
-    # -------------- Build YOLO --------------
-    model = GElan(cfg, is_val, deploy=False)
-
-    # -------------- Initialize YOLO --------------
-    for m in model.modules():
-        if isinstance(m, nn.BatchNorm2d):
-            m.eps = 1e-3
-            m.momentum = 0.03    
-            
-    # -------------- Build criterion --------------
-    criterion = None
-    if is_val:
-        # build criterion for training
-        criterion = SetCriterion(cfg)
-        
-    return model, criterion

+ 0 - 165
yolo/models/gelan/gelan.py

@@ -1,165 +0,0 @@
-# --------------- Torch components ---------------
-import torch
-import torch.nn as nn
-
-# --------------- Model components ---------------
-from .gelan_backbone import build_backbone
-from .gelan_neck     import SPPElan
-from .gelan_pafpn    import GElanPaFPN
-from .gelan_head     import GElanDetHead
-from .gelan_pred     import GElanPredLayer
-
-# --------------- External components ---------------
-from utils.misc import multiclass_nms
-
-
-# G-ELAN proposed by YOLOv9
-class GElan(nn.Module):
-    def __init__(self,
-                 cfg,
-                 is_val = False,
-                 deploy = False,
-                 ) -> None:
-        super(GElan, self).__init__()
-        # ---------------------- Basic setting ----------------------
-        self.cfg = cfg
-        self.deploy = deploy
-        self.num_classes = cfg.num_classes
-        ## Post-process parameters
-        self.topk_candidates = cfg.val_topk        if is_val else cfg.test_topk
-        self.conf_thresh     = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
-        self.nms_thresh      = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
-        self.no_multi_labels = False if is_val else True
-        
-        # ---------------------- Network Parameters ----------------------
-        ## Backbone
-        self.backbone = build_backbone(cfg)
-        self.neck     = SPPElan(cfg, self.backbone.feat_dims[-1])
-        self.backbone.feat_dims[-1] = self.neck.out_dim
-        ## PaFPN
-        self.fpn      = GElanPaFPN(cfg, self.backbone.feat_dims)
-        ## Detection head
-        self.head     = GElanDetHead(cfg, self.fpn.out_dims)
-        self.pred     = GElanPredLayer(cfg, self.head.cls_head_dim, self.head.reg_head_dim)
-
-    def switch_to_deploy(self,):
-        for m in self.modules():
-            if hasattr(m, "fuse_convs"):
-                m.fuse_convs()
-
-    def post_process(self, cls_preds, box_preds):
-        """
-        Input:
-            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
-            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
-        Output:
-            bboxes: np.array -> [N, 4]
-            scores: np.array -> [N,]
-            labels: np.array -> [N,]
-        """
-        all_scores = []
-        all_labels = []
-        all_bboxes = []
-        
-        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
-            cls_pred_i = cls_pred_i[0]
-            box_pred_i = box_pred_i[0]
-            if self.no_multi_labels:
-                # [M,]
-                scores, labels = torch.max(cls_pred_i.sigmoid(), dim=1)
-
-                # Keep top k top scoring indices only.
-                num_topk = min(self.topk_candidates, box_pred_i.size(0))
-
-                # topk candidates
-                predicted_prob, topk_idxs = scores.sort(descending=True)
-                topk_scores = predicted_prob[:num_topk]
-                topk_idxs = topk_idxs[:num_topk]
-
-                # filter out the proposals with low confidence score
-                keep_idxs = topk_scores > self.conf_thresh
-                scores = topk_scores[keep_idxs]
-                topk_idxs = topk_idxs[keep_idxs]
-
-                labels = labels[topk_idxs]
-                bboxes = box_pred_i[topk_idxs]
-            else:
-                # [M, C] -> [MC,]
-                scores_i = cls_pred_i.sigmoid().flatten()
-
-                # Keep top k top scoring indices only.
-                num_topk = min(self.topk_candidates, box_pred_i.size(0))
-
-                # torch.sort is actually faster than .topk (at least on GPUs)
-                predicted_prob, topk_idxs = scores_i.sort(descending=True)
-                topk_scores = predicted_prob[:num_topk]
-                topk_idxs = topk_idxs[:num_topk]
-
-                # filter out the proposals with low confidence score
-                keep_idxs = topk_scores > self.conf_thresh
-                scores = topk_scores[keep_idxs]
-                topk_idxs = topk_idxs[keep_idxs]
-
-                anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
-                labels = topk_idxs % self.num_classes
-
-                bboxes = box_pred_i[anchor_idxs]
-
-            all_scores.append(scores)
-            all_labels.append(labels)
-            all_bboxes.append(bboxes)
-
-        scores = torch.cat(all_scores, dim=0)
-        labels = torch.cat(all_labels, dim=0)
-        bboxes = torch.cat(all_bboxes, dim=0)
-
-        # to cpu & numpy
-        scores = scores.cpu().numpy()
-        labels = labels.cpu().numpy()
-        bboxes = bboxes.cpu().numpy()
-
-        # nms
-        scores, labels, bboxes = multiclass_nms(
-            scores, labels, bboxes, self.nms_thresh, self.num_classes)
-
-        return bboxes, scores, labels
-    
-    def forward(self, x):
-        # ---------------- Backbone ----------------
-        pyramid_feats = self.backbone(x)
-        # ---------------- Neck: SPP ----------------
-        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-        # ---------------- Neck: PaFPN ----------------
-        pyramid_feats = self.fpn(pyramid_feats)
-
-        # ---------------- Heads ----------------
-        cls_feats, reg_feats = self.head(pyramid_feats)
-
-        # ---------------- Preds ----------------
-        outputs = self.pred(cls_feats, reg_feats)
-        outputs['image_size'] = [x.shape[2], x.shape[3]]
-
-        if not self.training:
-            all_cls_preds = outputs['pred_cls']
-            all_box_preds = outputs['pred_box']
-
-            if self.deploy:
-                cls_preds = torch.cat(all_cls_preds, dim=1)[0]
-                box_preds = torch.cat(all_box_preds, dim=1)[0]
-                scores = cls_preds.sigmoid()
-                bboxes = box_preds
-                # [n_anchors_all, 4 + C]
-                outputs = torch.cat([bboxes, scores], dim=-1)
-
-            else:
-                # post process
-                bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
-                outputs = {
-                    "scores": scores,
-                    "labels": labels,
-                    "bboxes": bboxes
-                }
-        
-        return outputs
-    

+ 0 - 198
yolo/models/gelan/gelan_backbone.py

@@ -1,198 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .gelan_basic import BasicConv, RepGElanLayer, ADown
-except:
-    from  gelan_basic import BasicConv, RepGElanLayer, ADown
-
-# IN1K pretrained weight
-pretrained_urls = {
-    's': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/gelan_s_in1k_68.4.pth",
-    'c': "https://github.com/yjh0410/YOLO-Tutorial-v2/releases/download/in1k_pretrained_weight/gelan_c_in1k_76.7.pth",
-}
-
-# ----------------- GELAN backbone proposed by YOLOv9 -----------------
-class GElanBackbone(nn.Module):
-    def __init__(self, cfg):
-        super(GElanBackbone, self).__init__()
-        # ---------- Basic setting ----------
-        self.model_scale = cfg.scale
-        self.feat_dims = [cfg.backbone_feats["c1"][-1],  # 64
-                          cfg.backbone_feats["c2"][-1],  # 128
-                          cfg.backbone_feats["c3"][-1],  # 256
-                          cfg.backbone_feats["c4"][-1],  # 512
-                          cfg.backbone_feats["c5"][-1],  # 512
-                          ]
-        
-        # ---------- Network setting ----------
-        ## P1/2
-        self.layer_1 = BasicConv(3, cfg.backbone_feats["c1"][0],
-                                 kernel_size=3, padding=1, stride=2,
-                                 act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
-        # P2/4
-        self.layer_2 = nn.Sequential(
-            BasicConv(cfg.backbone_feats["c1"][0], cfg.backbone_feats["c2"][0],
-                      kernel_size=3, padding=1, stride=2,
-                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
-            RepGElanLayer(in_dim     = cfg.backbone_feats["c2"][0],
-                          inter_dims = cfg.backbone_feats["c2"][1],
-                          out_dim    = cfg.backbone_feats["c2"][2],
-                          num_blocks = cfg.backbone_depth,
-                          shortcut   = True,
-                          act_type   = cfg.bk_act,
-                          norm_type  = cfg.bk_norm,
-                          depthwise  = cfg.bk_depthwise)
-        )
-        # P3/8
-        self.layer_3 = nn.Sequential(
-            ADown(cfg.backbone_feats["c2"][2], cfg.backbone_feats["c3"][0],
-                  act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
-            RepGElanLayer(in_dim     = cfg.backbone_feats["c3"][0],
-                          inter_dims = cfg.backbone_feats["c3"][1],
-                          out_dim    = cfg.backbone_feats["c3"][2],
-                          num_blocks = cfg.backbone_depth,
-                          shortcut   = True,
-                          act_type   = cfg.bk_act,
-                          norm_type  = cfg.bk_norm,
-                          depthwise  = cfg.bk_depthwise)
-        )
-        # P4/16
-        self.layer_4 = nn.Sequential(
-            ADown(cfg.backbone_feats["c3"][2], cfg.backbone_feats["c4"][0],
-                  act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
-            RepGElanLayer(in_dim     = cfg.backbone_feats["c4"][0],
-                          inter_dims = cfg.backbone_feats["c4"][1],
-                          out_dim    = cfg.backbone_feats["c4"][2],
-                          num_blocks = cfg.backbone_depth,
-                          shortcut   = True,
-                          act_type   = cfg.bk_act,
-                          norm_type  = cfg.bk_norm,
-                          depthwise  = cfg.bk_depthwise)
-        )
-        # P5/32
-        self.layer_5 = nn.Sequential(
-            ADown(cfg.backbone_feats["c4"][2], cfg.backbone_feats["c5"][0],
-                  act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
-            RepGElanLayer(in_dim     = cfg.backbone_feats["c5"][0],
-                          inter_dims = cfg.backbone_feats["c5"][1],
-                          out_dim    = cfg.backbone_feats["c5"][2],
-                          num_blocks = cfg.backbone_depth,
-                          shortcut   = True,
-                          act_type   = cfg.bk_act,
-                          norm_type  = cfg.bk_norm,
-                          depthwise  = cfg.bk_depthwise)
-        )
-
-        # Initialize all layers
-        self.init_weights()
-
-        # Load imagenet pretrained weight
-        if cfg.use_pretrained:
-            self.load_pretrained()
-
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                m.reset_parameters()
-
-    def load_pretrained(self):
-        url = pretrained_urls[self.model_scale]
-        if url is not None:
-            print('Loading backbone pretrained weight from : {}'.format(url))
-            # checkpoint state dict
-            checkpoint = torch.hub.load_state_dict_from_url(
-                url=url, map_location="cpu", check_hash=True)
-            checkpoint_state_dict = checkpoint.pop("model")
-            # model state dict
-            model_state_dict = self.state_dict()
-            # check
-            for k in list(checkpoint_state_dict.keys()):
-                if k in model_state_dict:
-                    shape_model = tuple(model_state_dict[k].shape)
-                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
-                    if shape_model != shape_checkpoint:
-                        checkpoint_state_dict.pop(k)
-                else:
-                    checkpoint_state_dict.pop(k)
-                    print('Unused key: ', k)
-            # load the weight
-            self.load_state_dict(checkpoint_state_dict)
-        else:
-            print('No pretrained weight for model scale: {}.'.format(self.model_scale))
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-        outputs = [c3, c4, c5]
-
-        return outputs
-
-
-# ------------ Functions ------------
-def build_backbone(cfg): 
-    # model
-    if   cfg.backbone == "gelan":
-        backbone = GElanBackbone(cfg)
-    else:
-        raise NotImplementedError("Unknown gelan backbone: {}".format(cfg.backbone))
-        
-    return backbone
-
-
-if __name__ == '__main__':
-    import time
-    from thop import profile
-    class BaseConfig(object):
-        def __init__(self) -> None:
-            self.backbone = 'gelan'
-            self.use_pretrained = True
-            self.bk_act = 'silu'
-            self.bk_norm = 'BN'
-            self.bk_depthwise = False
-            # # Gelan-C scale
-            # self.backbone_feats = {
-            #     "c1": [64],
-            #     "c2": [128, [128, 64], 256],
-            #     "c3": [256, [256, 128], 512],
-            #     "c4": [512, [512, 256], 512],
-            #     "c5": [512, [512, 256], 512],
-            # }
-            # self.scale = "l"
-            # self.backbone_depth = 1
-            # Gelan-S scale
-            self.backbone_feats = {
-                "c1": [32],
-                "c2": [64,  [64, 32],   64],
-                "c3": [64,  [64, 32],   128],
-                "c4": [128, [128, 64],  256],
-                "c5": [256, [256, 128], 256],
-            }
-            self.scale = "s"
-            self.backbone_depth = 3
-    # Define the model config
-    cfg = BaseConfig()
-
-    # Build the GELAN backbone
-    model = build_backbone(cfg)
-
-    # Generate random input data
-    x = torch.randn(1, 3, 640, 640)
-
-    # Forward inference
-    outputs = model(x)
-
-    # Print the output shapes
-    for out in outputs:
-        print(out.shape)
-
-    # Compute the number of parameters and theoretical FLOPs
-    print('============ Params & FLOPs ============')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
-    

+ 0 - 312
yolo/models/gelan/gelan_basic.py

@@ -1,312 +0,0 @@
-import numpy as np
-import torch
-import torch.nn as nn
-from typing import List
-
-
-# --------------------- Basic modules ---------------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-    elif act_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-        
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
-class BasicConv(nn.Module):
-    def __init__(self, 
-                 in_dim,                   # in channels
-                 out_dim,                  # out channels 
-                 kernel_size=1,            # kernel size 
-                 padding=0,                # padding
-                 stride=1,                 # stride
-                 dilation=1,               # dilation
-                 group=1,                  # group
-                 act_type  :str = 'lrelu', # activation
-                 norm_type :str = 'BN',    # normalization
-                 depthwise :bool = False
-                ):
-        super(BasicConv, self).__init__()
-        self.depthwise = depthwise
-        if not depthwise:
-            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=group)
-            self.norm = get_norm(norm_type, out_dim)
-        else:
-            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
-            self.norm1 = get_norm(norm_type, in_dim)
-            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
-            self.norm2 = get_norm(norm_type, out_dim)
-        self.act  = get_activation(act_type)
-
-    def forward(self, x):
-        if not self.depthwise:
-            return self.act(self.norm(self.conv(x)))
-        else:
-            # Depthwise conv
-            x = self.norm1(self.conv1(x))
-            # Pointwise conv
-            x = self.act(self.norm2(self.conv2(x)))
-            return x
-
-
-# --------------------- GELAN modules (from yolov9) ---------------------
-class ADown(nn.Module):
-    def __init__(self, in_dim, out_dim, act_type="silu", norm_type="BN", depthwise=False):
-        super().__init__()
-        inter_dim = out_dim // 2
-        self.conv_layer_1 = BasicConv(in_dim // 2, inter_dim,
-                                    kernel_size=3, padding=1, stride=2,
-                                    act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.conv_layer_2 = BasicConv(in_dim // 2, inter_dim, kernel_size=1,
-                                    act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-    def forward(self, x):
-        x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
-        x1,x2 = x.chunk(2, 1)
-        x1 = self.conv_layer_1(x1)
-        x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
-        x2 = self.conv_layer_2(x2)
-
-        return torch.cat((x1, x2), 1)
-
-class RepConvN(nn.Module):
-    """RepConv is a basic rep-style block, including training and deploy status
-    This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
-    """
-    def __init__(self, in_dim, out_dim, k=3, s=1, p=1, g=1, act_type='silu', norm_type='BN', depthwise=False):
-        super().__init__()
-        assert k == 3 and p == 1
-        self.g = g
-        self.in_dim = in_dim
-        self.out_dim = out_dim
-        self.act = get_activation(act_type)
-
-        self.bn = None
-        self.conv1 = BasicConv(in_dim, out_dim,
-                               kernel_size=k, padding=p, stride=s, group=g,
-                               act_type=None, norm_type=norm_type, depthwise=depthwise)
-        self.conv2 = BasicConv(in_dim, out_dim,
-                               kernel_size=1, padding=(p - k // 2), stride=s, group=g,
-                               act_type=None, norm_type=norm_type, depthwise=depthwise)
-
-    def forward(self, x):
-        """Forward process"""
-        if hasattr(self, 'conv'):
-            return self.forward_fuse(x)
-        else:
-            id_out = 0 if self.bn is None else self.bn(x)
-            return self.act(self.conv1(x) + self.conv2(x) + id_out)
-
-    def forward_fuse(self, x):
-        """Forward process"""
-        return self.act(self.conv(x))
-
-    def get_equivalent_kernel_bias(self):
-        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
-        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
-        kernelid, biasid = self._fuse_bn_tensor(self.bn)
-        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
-
-    def _avg_to_3x3_tensor(self, avgp):
-        channels = self.in_dim
-        groups = self.g
-        kernel_size = avgp.kernel_size
-        input_dim = channels // groups
-        k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
-        k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
-        return k
-
-    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
-        if kernel1x1 is None:
-            return 0
-        else:
-            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
-
-    def _fuse_bn_tensor(self, branch):
-        if branch is None:
-            return 0, 0
-        if isinstance(branch, BasicConv):
-            kernel       = branch.conv.weight
-            running_mean = branch.norm.running_mean
-            running_var  = branch.norm.running_var
-            gamma        = branch.norm.weight
-            beta         = branch.norm.bias
-            eps          = branch.norm.eps
-        elif isinstance(branch, nn.BatchNorm2d):
-            if not hasattr(self, 'id_tensor'):
-                input_dim = self.in_dim // self.g
-                kernel_value = np.zeros((self.in_dim, input_dim, 3, 3), dtype=np.float32)
-                for i in range(self.in_dim):
-                    kernel_value[i, i % input_dim, 1, 1] = 1
-                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
-            kernel       = self.id_tensor
-            running_mean = branch.running_mean
-            running_var  = branch.running_var
-            gamma        = branch.weight
-            beta         = branch.bias
-            eps          = branch.eps
-        std = (running_var + eps).sqrt()
-        t = (gamma / std).reshape(-1, 1, 1, 1)
-        return kernel * t, beta - running_mean * gamma / std
-
-    def fuse_convs(self):
-        if hasattr(self, 'conv'):
-            return
-        kernel, bias = self.get_equivalent_kernel_bias()
-        self.conv = nn.Conv2d(in_channels  = self.conv1.conv.in_channels,
-                              out_channels = self.conv1.conv.out_channels,
-                              kernel_size  = self.conv1.conv.kernel_size,
-                              stride       = self.conv1.conv.stride,
-                              padding      = self.conv1.conv.padding,
-                              dilation     = self.conv1.conv.dilation,
-                              groups       = self.conv1.conv.groups,
-                              bias         = True).requires_grad_(False)
-        self.conv.weight.data = kernel
-        self.conv.bias.data = bias
-        for para in self.parameters():
-            para.detach_()
-        self.__delattr__('conv1')
-        self.__delattr__('conv2')
-        if hasattr(self, 'nm'):
-            self.__delattr__('nm')
-        if hasattr(self, 'bn'):
-            self.__delattr__('bn')
-        if hasattr(self, 'id_tensor'):
-            self.__delattr__('id_tensor')
-
-class RepNBottleneck(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 shortcut=True,
-                 kernel_size=(3, 3),
-                 expansion=0.5,
-                 act_type='silu',
-                 norm_type='BN',
-                 depthwise=False
-                 ):
-        super().__init__()
-        inter_dim = round(out_dim * expansion)
-        self.conv_layer_1 = RepConvN(in_dim, inter_dim, kernel_size[0], p=kernel_size[0]//2, s=1, act_type=act_type, norm_type=norm_type)
-        self.conv_layer_2 = BasicConv(inter_dim, out_dim, kernel_size[1], padding=kernel_size[1]//2, stride=1, act_type=act_type, norm_type=norm_type)
-        self.add = shortcut and in_dim == out_dim
-
-    def forward(self, x):
-        h = self.conv_layer_2(self.conv_layer_1(x))
-        return x + h if self.add else h
-
-class RepNCSP(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 num_blocks=1,
-                 shortcut=True,
-                 expansion=0.5,
-                 act_type='silu',
-                 norm_type='BN',
-                 depthwise=False
-                 ):
-        super().__init__()
-        inter_dim = int(out_dim * expansion)
-        self.conv_layer_1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.conv_layer_2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.conv_layer_3 = BasicConv(2 * inter_dim, out_dim, kernel_size=1)
-        self.module       = nn.Sequential(*(RepNBottleneck(inter_dim,
-                                                           inter_dim,
-                                                           kernel_size = [3, 3],
-                                                           shortcut    = shortcut,
-                                                           expansion   = 1.0,
-                                                           act_type    = act_type,
-                                                           norm_type   = norm_type,
-                                                           depthwise   = depthwise)
-                                                           for _ in range(num_blocks)))
-
-    def forward(self, x):
-        x1 = self.conv_layer_1(x)
-        x2 = self.module(self.conv_layer_2(x))
-
-        return self.conv_layer_3(torch.cat([x1, x2], dim=1))
-
-class RepGElanLayer(nn.Module):
-    """YOLOv9's GELAN module"""
-    def __init__(self,
-                 in_dim     :int,
-                 inter_dims :List,
-                 out_dim    :int,
-                 num_blocks :int   = 1,
-                 shortcut   :bool  = False,
-                 act_type   :str   = 'silu',
-                 norm_type  :str   = 'BN',
-                 depthwise  :bool  = False,
-                 ) -> None:
-        super(RepGElanLayer, self).__init__()
-        # ----------- Basic parameters -----------
-        self.in_dim = in_dim
-        self.inter_dims = inter_dims
-        self.out_dim = out_dim
-
-        # ----------- Network parameters -----------
-        self.conv_layer_1  = BasicConv(in_dim, inter_dims[0], kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.elan_module_1 = nn.Sequential(
-             RepNCSP(inter_dims[0]//2,
-                     inter_dims[1],
-                     num_blocks  = num_blocks,
-                     shortcut    = shortcut,
-                     expansion   = 0.5,
-                     act_type    = act_type,
-                     norm_type   = norm_type,
-                     depthwise   = depthwise),
-            BasicConv(inter_dims[1], inter_dims[1],
-                      kernel_size=3, padding=1,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        self.elan_module_2 = nn.Sequential(
-             RepNCSP(inter_dims[1],
-                     inter_dims[1],
-                     num_blocks  = num_blocks,
-                     shortcut    = shortcut,
-                     expansion   = 0.5,
-                     act_type    = act_type,
-                     norm_type   = norm_type,
-                     depthwise   = depthwise),
-            BasicConv(inter_dims[1], inter_dims[1],
-                      kernel_size=3, padding=1,
-                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        )
-        self.conv_layer_2 = BasicConv(inter_dims[0] + 2*self.inter_dims[1], out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-
-
-    def forward(self, x):
-        # Input proj
-        x1, x2 = torch.chunk(self.conv_layer_1(x), 2, dim=1)
-        out = list([x1, x2])
-
-        # ELAN module
-        out.append(self.elan_module_1(out[-1]))
-        out.append(self.elan_module_2(out[-1]))
-
-        # Output proj
-        out = self.conv_layer_2(torch.cat(out, dim=1))
-
-        return out
-    

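RepConvN above follows the RepVGG-style structural re-parameterization: during training it runs a 3x3 branch and a 1x1 branch (each Conv+BN), and fuse_convs() folds both branches and their BN statistics into a single 3x3 convolution for deployment. A minimal sketch of how that equivalence could be checked numerically, assuming gelan_basic.py is importable from the working directory; the module must be in eval() mode so BN uses its running statistics:

import torch
from gelan_basic import RepConvN  # assumes this file is on the import path

m = RepConvN(32, 32, k=3, s=1, p=1, act_type='silu', norm_type='BN')
m.eval()  # fusion is only exact when BN uses running stats

x = torch.randn(2, 32, 40, 40)
with torch.no_grad():
    y_multi_branch = m(x)   # 3x3 branch + 1x1 branch
    m.fuse_convs()          # fold both branches into one 3x3 conv
    y_fused = m(x)          # single fused conv

print(torch.allclose(y_multi_branch, y_fused, atol=1e-5))  # expected: True
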
+ 0 - 176
yolo/models/gelan/gelan_head.py

@@ -1,176 +0,0 @@
-import torch
-import torch.nn as nn
-
-try:
-    from .gelan_basic import BasicConv
-except:
-    from  gelan_basic import BasicConv
-    
-
-# Single-level Head
-class SingleLevelHead(nn.Module):
-    def __init__(self,
-                 in_dim       :int  = 256,
-                 cls_head_dim :int  = 256,
-                 reg_head_dim :int  = 256,
-                 num_cls_head :int  = 2,
-                 num_reg_head :int  = 2,
-                 act_type     :str  = "silu",
-                 norm_type    :str  = "BN",
-                 depthwise    :bool = False):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.in_dim = in_dim
-        self.num_cls_head = num_cls_head
-        self.num_reg_head = num_reg_head
-        self.act_type = act_type
-        self.norm_type = norm_type
-        self.depthwise = depthwise
-        
-        # --------- Network Parameters ----------
-        ## cls head
-        cls_feats = []
-        self.cls_head_dim = cls_head_dim
-        for i in range(num_cls_head):
-            if i == 0:
-                cls_feats.append(
-                    BasicConv(in_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
-            else:
-                cls_feats.append(
-                    BasicConv(self.cls_head_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
-        ## reg head
-        reg_feats = []
-        self.reg_head_dim = reg_head_dim
-        for i in range(num_reg_head):
-            if i == 0:
-                reg_feats.append(
-                    BasicConv(in_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
-            else:
-                reg_feats.append(
-                    BasicConv(self.reg_head_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1, group=4,
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
-        self.cls_feats = nn.Sequential(*cls_feats)
-        self.reg_feats = nn.Sequential(*reg_feats)
-
-        self.init_weights()
-        
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, x):
-        """
-            in_feats: (Tensor) [B, C, H, W]
-        """
-        cls_feats = self.cls_feats(x)
-        reg_feats = self.reg_feats(x)
-
-        return cls_feats, reg_feats
-    
-# Multi-level Head
-class GElanDetHead(nn.Module):
-    def __init__(self, cfg, in_dims):
-        super().__init__()
-        ## ----------- Network Parameters -----------
-        self.multi_level_heads = nn.ModuleList(
-            [SingleLevelHead(in_dim       = in_dims[level],
-                             cls_head_dim = max(in_dims[0], min(cfg.num_classes * 2, 128)),
-                             reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
-                             num_cls_head = cfg.num_cls_head,
-                             num_reg_head = cfg.num_reg_head,
-                             act_type     = cfg.head_act,
-                             norm_type    = cfg.head_norm,
-                             depthwise    = cfg.head_depthwise)
-                             for level in range(cfg.num_levels)
-                             ])
-        # --------- Basic Parameters ----------
-        self.in_dims = in_dims
-        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
-        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
-
-
-    def forward(self, feats):
-        """
-            feats: List[(Tensor)] [[B, C, H, W], ...]
-        """
-        cls_feats = []
-        reg_feats = []
-        for feat, head in zip(feats, self.multi_level_heads):
-            # ---------------- Pred ----------------
-            cls_feat, reg_feat = head(feat)
-
-            cls_feats.append(cls_feat)
-            reg_feats.append(reg_feat)
-
-        return cls_feats, reg_feats
-    
-
-
-if __name__=='__main__':
-    import time
-    from thop import profile
-    # Model config
-    
-    # GElan-Base config
-    class GElanBaseConfig(object):
-        def __init__(self) -> None:
-            # ---------------- Model config ----------------
-            self.reg_max  = 16
-            self.out_stride = [8, 16, 32]
-            self.max_stride = 32
-            self.num_levels = 3
-            ## Head
-            self.head_act  = 'lrelu'
-            self.head_norm = 'BN'
-            self.head_depthwise = False
-            self.num_cls_head   = 2
-            self.num_reg_head   = 2
-
-    cfg = GElanBaseConfig()
-    cfg.num_classes = 20
-
-    # Build a head
-    fpn_dims = [128, 256, 256]
-    pyramid_feats = [torch.randn(1, fpn_dims[0], 80, 80),
-                     torch.randn(1, fpn_dims[1], 40, 40),
-                     torch.randn(1, fpn_dims[2], 20, 20)]
-    head = GElanDetHead(cfg, fpn_dims)
-
-
-    # Inference
-    t0 = time.time()
-    cls_feats, reg_feats = head(pyramid_feats)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print("====== GElan Head output ======")
-    for level, (cls_f, reg_f) in enumerate(zip(cls_feats, reg_feats)):
-        print("- Level-{} : ".format(level), cls_f.shape, reg_f.shape)
-
-    flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
-    

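GElanDetHead above sets one shared head width for all levels from the first FPN channel count: cls_head_dim = max(in_dims[0], min(num_classes * 2, 128)) and reg_head_dim = max(in_dims[0] // 4, 16, 4 * reg_max). Plugging in the test config above (in_dims[0] = 128, num_classes = 20, reg_max = 16) as a quick arithmetic check:

in_dim0, num_classes, reg_max = 128, 20, 16

cls_head_dim = max(in_dim0, min(num_classes * 2, 128))  # max(128, min(40, 128)) = 128
reg_head_dim = max(in_dim0 // 4, 16, 4 * reg_max)       # max(32, 16, 64) = 64
print(cls_head_dim, reg_head_dim)                        # 128 64
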
+ 0 - 76
yolo/models/gelan/gelan_neck.py

@@ -1,76 +0,0 @@
-import torch
-import torch.nn as nn
-
-from .gelan_basic import BasicConv
-
-
-# SPPF (from yolov5)
-class SPPF(nn.Module):
-    """
-        This code is adapted from https://github.com/ultralytics/yolov5
-    """
-    def __init__(self, cfg, in_dim, out_dim):
-        super().__init__()
-        ## ----------- Basic Parameters -----------
-        inter_dim = round(in_dim * cfg.neck_expand_ratio)
-        self.out_dim = out_dim
-        ## ----------- Network Parameters -----------
-        self.cv1 = BasicConv(in_dim, inter_dim,
-                             kernel_size=1, padding=0, stride=1,
-                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
-        self.cv2 = BasicConv(inter_dim * 4, out_dim,
-                             kernel_size=1, padding=0, stride=1,
-                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
-        self.m = nn.MaxPool2d(kernel_size=cfg.spp_pooling_size,
-                              stride=1,
-                              padding=cfg.spp_pooling_size // 2)
-
-        # Initialize all layers
-        self.init_weights()
-
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, x):
-        x = self.cv1(x)
-        y1 = self.m(x)
-        y2 = self.m(y1)
-
-        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
-
-# SPP-ELAN (from yolov9)
-class SPPElan(nn.Module):
-    def __init__(self, cfg, in_dim):
-        """SPPElan looks like the SPPF."""
-        super().__init__()
-        ## ----------- Basic Parameters -----------
-        self.in_dim = in_dim
-        self.inter_dim = cfg.spp_inter_dim
-        self.out_dim   = cfg.spp_out_dim
-        ## ----------- Network Parameters -----------
-        self.conv_layer_1 = BasicConv(in_dim, self.inter_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
-        self.conv_layer_2 = BasicConv(self.inter_dim * 4, self.out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
-        self.pool_layer   = nn.MaxPool2d(kernel_size=cfg.spp_pooling_size, stride=1, padding=cfg.spp_pooling_size // 2)
-
-        # Initialize all layers
-        self.init_weights()
-
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, x):
-        y = [self.conv_layer_1(x)]
-        y.extend(self.pool_layer(y[-1]) for _ in range(3))
-        
-        return self.conv_layer_2(torch.cat(y, 1))
-    

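Both SPPF and SPPElan above chain three identical max-pooling layers instead of running pools of growing kernel size in parallel: with stride 1 and padding k // 2, applying a 5x5 max pool twice equals a single 9x9 pool, and three times equals a 13x13 pool, which is why this form matches the original SPP output at lower cost. A quick numerical check of that equivalence, assuming the usual pooling size of 5 (cfg.spp_pooling_size may differ):

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 32, 32)

p5  = F.max_pool2d(x, kernel_size=5,  stride=1, padding=2)
p9  = F.max_pool2d(x, kernel_size=9,  stride=1, padding=4)
p13 = F.max_pool2d(x, kernel_size=13, stride=1, padding=6)

# chaining the 5x5 pool reproduces the larger kernels
print(torch.equal(F.max_pool2d(p5, 5, 1, 2), p9))                          # True
print(torch.equal(F.max_pool2d(F.max_pool2d(p5, 5, 1, 2), 5, 1, 2), p13))  # True
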
+ 0 - 158
yolo/models/gelan/gelan_pafpn.py

@@ -1,158 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import List
-
-try:
-    from .gelan_basic import RepGElanLayer, ADown
-except:
-    from  gelan_basic import RepGElanLayer, ADown
-
-
-# PaFPN-ELAN
-class GElanPaFPN(nn.Module):
-    def __init__(self,
-                 cfg,
-                 in_dims :List = [256, 512, 256],
-                 ) -> None:
-        super(GElanPaFPN, self).__init__()
-        print('==============================')
-        print('FPN: {}'.format("GELAN PaFPN"))
-        # --------------------------- Basic Parameters ---------------------------
-        self.in_dims = in_dims[::-1]
-        self.out_dims = [cfg.fpn_feats_td["p3"][1], cfg.fpn_feats_bu["p4"][1], cfg.fpn_feats_bu["p5"][1]]
-
-        # ---------------- Top down ----------------
-        ## P5 -> P4
-        self.top_down_layer_1 = RepGElanLayer(in_dim     = self.in_dims[0] + self.in_dims[1],
-                                              inter_dims = cfg.fpn_feats_td["p4"][0],
-                                              out_dim    = cfg.fpn_feats_td["p4"][1],
-                                              num_blocks = cfg.fpn_depth,
-                                              shortcut   = False,
-                                              act_type   = cfg.fpn_act,
-                                              norm_type  = cfg.fpn_norm,
-                                              depthwise  = cfg.fpn_depthwise,
-                                              )
-        ## P4 -> P3
-        self.top_down_layer_2 = RepGElanLayer(in_dim     = cfg.fpn_feats_td["p4"][1] + self.in_dims[2],
-                                              inter_dims = cfg.fpn_feats_td["p3"][0],
-                                              out_dim    = cfg.fpn_feats_td["p3"][1],
-                                              num_blocks = cfg.fpn_depth,
-                                              shortcut   = False,
-                                              act_type   = cfg.fpn_act,
-                                              norm_type  = cfg.fpn_norm,
-                                              depthwise  = cfg.fpn_depthwise,
-                                              )
-        # ---------------- Bottom up ----------------
-        ## P3 -> P4
-        self.dowmsample_layer_1 = ADown(cfg.fpn_feats_td["p3"][1], cfg.fpn_feats_td["p3"][1],
-                                        act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
-        self.bottom_up_layer_1  = RepGElanLayer(in_dim     = cfg.fpn_feats_td["p3"][1] + cfg.fpn_feats_td["p4"][1],
-                                                inter_dims = cfg.fpn_feats_bu["p4"][0],
-                                                out_dim    = cfg.fpn_feats_bu["p4"][1],
-                                                num_blocks = cfg.fpn_depth,
-                                                shortcut   = False,
-                                                act_type   = cfg.fpn_act,
-                                                norm_type  = cfg.fpn_norm,
-                                                depthwise  = cfg.fpn_depthwise,
-                                                )
-        ## P4 -> P5
-        self.dowmsample_layer_2 = ADown(cfg.fpn_feats_bu["p4"][1], cfg.fpn_feats_bu["p4"][1],
-                                        act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
-        self.bottom_up_layer_2  = RepGElanLayer(in_dim     = cfg.fpn_feats_td["p4"][1] + self.in_dims[0],
-                                                inter_dims = cfg.fpn_feats_bu["p5"][0],
-                                                out_dim    = cfg.fpn_feats_bu["p5"][1],
-                                                num_blocks = cfg.fpn_depth,
-                                                shortcut   = False,
-                                                act_type   = cfg.fpn_act,
-                                                norm_type  = cfg.fpn_norm,
-                                                depthwise  = cfg.fpn_depthwise,
-                                                )
-        
-        self.init_weights()
-        
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, features):
-        c3, c4, c5 = features
-
-        # ------------------ Top down FPN ------------------
-        ## P5 -> P4
-        p5_up = F.interpolate(c5, scale_factor=2.0)
-        p4 = self.top_down_layer_1(torch.cat([p5_up, c4], dim=1))
-
-        ## P4 -> P3
-        p4_up = F.interpolate(p4, scale_factor=2.0)
-        p3 = self.top_down_layer_2(torch.cat([p4_up, c3], dim=1))
-
-        # ------------------ Bottom up FPN ------------------
-        ## P3 -> P4
-        p3_ds = self.dowmsample_layer_1(p3)
-        p4 = self.bottom_up_layer_1(torch.cat([p3_ds, p4], dim=1))
-
-        ## P4 -> P5
-        p4_ds = self.dowmsample_layer_2(p4)
-        p5 = self.bottom_up_layer_2(torch.cat([p4_ds, c5], dim=1))
-
-        out_feats = [p3, p4, p5] # [P3, P4, P5]
-
-        return out_feats
-
-
-if __name__=='__main__':
-    import time
-    from thop import profile
-    # Model config
-    
-    # GElan-Base config
-    class GElanBaseConfig(object):
-        def __init__(self) -> None:
-            # ---------------- Model config ----------------
-            self.width    = 0.50
-            self.depth    = 0.34
-            self.ratio    = 2.0
-            self.out_stride = [8, 16, 32]
-            self.max_stride = 32
-            self.num_levels = 3
-            ## FPN
-            self.fpn      = 'gelan_pafpn'
-            self.fpn_act  = 'silu'
-            self.fpn_norm = 'BN'
-            self.fpn_depthwise = False
-            self.fpn_depth    = 3
-            self.fpn_feats_td = {
-                "p4": [[256, 128], 256],
-                "p3": [[128, 64],  128],
-            }
-            self.fpn_feats_bu = {
-                "p4": [[256, 128], 256],
-                "p5": [[256, 128], 256],
-            }
-
-    cfg = GElanBaseConfig()
-    # Build a head
-    in_dims  = [128, 256, 256]
-    fpn = GElanPaFPN(cfg, in_dims)
-
-    # Inference
-    x = [torch.randn(1, in_dims[0], 80, 80),
-         torch.randn(1, in_dims[1], 40, 40),
-         torch.randn(1, in_dims[2], 20, 20)]
-    t0 = time.time()
-    output = fpn(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print('====== FPN output ====== ')
-    for level, feat in enumerate(output):
-        print("- Level-{} : ".format(level), feat.shape)
-
-    flops, params = profile(fpn, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))

+ 0 - 155
yolo/models/gelan/gelan_pred.py

@@ -1,155 +0,0 @@
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-# Single-level pred layer
-class SingleLevelPredLayer(nn.Module):
-    def __init__(self,
-                 cls_dim     :int = 256,
-                 reg_dim     :int = 256,
-                 stride      :int = 32,
-                 reg_max     :int = 16,
-                 num_classes :int = 80,
-                 num_coords  :int = 4):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.stride = stride
-        self.cls_dim = cls_dim
-        self.reg_dim = reg_dim
-        self.reg_max = reg_max
-        self.num_classes = num_classes
-        self.num_coords = num_coords
-
-        # --------- Network Parameters ----------
-        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
-        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1, groups=4)                
-
-        self.init_bias()
-        
-    def init_bias(self):
-        # cls pred bias
-        b = self.cls_pred.bias.view(1, -1)
-        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
-        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # reg pred bias
-        b = self.reg_pred.bias.view(-1, )
-        b.data.fill_(1.0)
-        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        w = self.reg_pred.weight
-        w.data.fill_(0.)
-        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
-    def generate_anchors(self, fmp_size):
-        """
-            fmp_size: (List) [H, W]
-        """
-        # generate grid cells
-        fmp_h, fmp_w = fmp_size
-        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-        # [H, W, 2] -> [HW, 2]
-        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
-        anchors += 0.5  # add center offset
-        anchors *= self.stride
-
-        return anchors
-        
-    def forward(self, cls_feat, reg_feat):
-        # pred
-        cls_pred = self.cls_pred(cls_feat)
-        reg_pred = self.reg_pred(reg_feat)
-
-        # generate anchor boxes: [M, 4]
-        B, _, H, W = cls_pred.size()
-        fmp_size = [H, W]
-        anchors = self.generate_anchors(fmp_size)
-        anchors = anchors.to(cls_pred.device)
-        # stride tensor: [M, 1]
-        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
-        
-        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
-        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
-        
-        # output dict
-        outputs = {"pred_cls": cls_pred,            # List(Tensor) [B, M, C]
-                   "pred_reg": reg_pred,            # List(Tensor) [B, M, 4*(reg_max)]
-                   "anchors": anchors,              # List(Tensor) [M, 2]
-                   "strides": self.stride,          # List(Int) = [8, 16, 32]
-                   "stride_tensor": stride_tensor   # List(Tensor) [M, 1]
-                   }
-
-        return outputs
-
-# Multi-level pred layer
-class GElanPredLayer(nn.Module):
-    def __init__(self,
-                 cfg,
-                 cls_dim,
-                 reg_dim,
-                 ):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.cfg = cfg
-        self.cls_dim = cls_dim
-        self.reg_dim = reg_dim
-
-        # ----------- Network Parameters -----------
-        ## pred layers
-        self.multi_level_preds = nn.ModuleList(
-            [SingleLevelPredLayer(cls_dim     = cls_dim,
-                                  reg_dim     = reg_dim,
-                                  stride      = cfg.out_stride[level],
-                                  reg_max     = cfg.reg_max,
-                                  num_classes = cfg.num_classes,
-                                  num_coords  = 4 * cfg.reg_max)
-                                  for level in range(cfg.num_levels)
-                                  ])
-        ## proj conv
-        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
-        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
-        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, cfg.reg_max, 1, 1]), requires_grad=False)
-
-    def forward(self, cls_feats, reg_feats):
-        all_anchors = []
-        all_strides = []
-        all_cls_preds = []
-        all_reg_preds = []
-        all_box_preds = []
-        for level in range(self.cfg.num_levels):
-            # -------------- Single-level prediction --------------
-            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
-
-            # -------------- Decode bbox --------------
-            B, M = outputs["pred_reg"].shape[:2]
-            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
-            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.cfg.reg_max])
-            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
-            # [B, reg_max, 4, M] -> [B, 1, 4, M]
-            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
-            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
-            ## tlbr -> xyxy
-            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.cfg.out_stride[level]
-            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.cfg.out_stride[level]
-            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
-
-            # collect results
-            all_cls_preds.append(outputs["pred_cls"])
-            all_reg_preds.append(outputs["pred_reg"])
-            all_box_preds.append(box_pred)
-            all_anchors.append(outputs["anchors"])
-            all_strides.append(outputs["stride_tensor"])
-        
-        # output dict
-        outputs = {"pred_cls":      all_cls_preds,         # List(Tensor) [B, M, C]
-                   "pred_reg":      all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
-                   "pred_box":      all_box_preds,         # List(Tensor) [B, M, 4]
-                   "anchors":       all_anchors,           # List(Tensor) [M, 2]
-                   "stride_tensor": all_strides,           # List(Tensor) [M, 1]
-                   "strides":       self.cfg.out_stride,   # List(Int) = [8, 16, 32]
-                   }
-
-        return outputs

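The proj_conv in GElanPredLayer above is a frozen 1x1 convolution whose weights are simply 0, 1, ..., reg_max - 1, so applying it to the softmax of the 4 * reg_max regression logits computes the expected bin index for each of the four distances; that expectation is the standard DFL decoding. A small sketch showing the two forms agree, assuming reg_max = 16:

import torch
import torch.nn.functional as F

reg_max = 16
logits = torch.randn(1, 8400, 4, reg_max)       # [B, M, 4, reg_max] regression logits
probs = F.softmax(logits, dim=-1)
bins = torch.arange(reg_max, dtype=torch.float)

# expectation form: sum_i i * softmax(logits)_i
dist_expect = (probs * bins).sum(dim=-1)        # [B, M, 4], distances in stride units

# frozen 1x1 conv form, as in gelan_pred.py
proj = torch.nn.Conv2d(reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
proj.weight.data[:] = bins.view(1, reg_max, 1, 1)
dist_conv = proj(probs.permute(0, 3, 2, 1)).view(1, 4, -1).permute(0, 2, 1)

print(torch.allclose(dist_expect, dist_conv, atol=1e-5))  # True
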
+ 0 - 187
yolo/models/gelan/loss.py

@@ -1,187 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from utils.box_ops import bbox2dist, bbox_iou
-from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
-
-from .matcher import TaskAlignedAssigner
-
-
-class SetCriterion(object):
-    def __init__(self, cfg):
-        # --------------- Basic parameters ---------------
-        self.cfg = cfg
-        self.reg_max = cfg.reg_max
-        self.num_classes = cfg.num_classes
-        # --------------- Loss config ---------------
-        self.loss_cls_weight = cfg.loss_cls
-        self.loss_box_weight = cfg.loss_box
-        self.loss_dfl_weight = cfg.loss_dfl
-        # --------------- Matcher config ---------------
-        self.matcher = TaskAlignedAssigner(num_classes     = cfg.num_classes,
-                                           topk_candidates = cfg.tal_topk_candidates,
-                                           alpha           = cfg.tal_alpha,
-                                           beta            = cfg.tal_beta
-                                           )
-
-    def loss_classes(self, pred_cls, gt_score):
-        # compute bce loss
-        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
-
-        return loss_cls
-    
-    def loss_bboxes(self, pred_box, gt_box, bbox_weight):
-        # regression loss
-        ious = bbox_iou(pred_box, gt_box, xywh=False, CIoU=True)
-        loss_box = (1.0 - ious.squeeze(-1)) * bbox_weight
-
-        return loss_box
-    
-    def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
-        # rescale coords by stride
-        gt_box_s = gt_box / stride
-        anchor_s = anchor / stride
-
-        # compute deltas
-        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.reg_max - 1)
-
-        gt_left = gt_ltrb_s.to(torch.long)
-        gt_right = gt_left + 1
-
-        weight_left = gt_right.to(torch.float) - gt_ltrb_s
-        weight_right = 1 - weight_left
-
-        # loss left
-        loss_left = F.cross_entropy(
-            pred_reg.view(-1, self.reg_max),
-            gt_left.view(-1),
-            reduction='none').view(gt_left.shape) * weight_left
-        # loss right
-        loss_right = F.cross_entropy(
-            pred_reg.view(-1, self.reg_max),
-            gt_right.view(-1),
-            reduction='none').view(gt_left.shape) * weight_right
-
-        loss_dfl = (loss_left + loss_right).mean(-1)
-        
-        if bbox_weight is not None:
-            loss_dfl *= bbox_weight
-
-        return loss_dfl
-
-    def __call__(self, outputs, targets):        
-        """
-            outputs['pred_cls']: List(Tensor) [B, M, C]
-            outputs['pred_reg']: List(Tensor) [B, M, 4*(reg_max+1)]
-            outputs['pred_box']: List(Tensor) [B, M, 4]
-            outputs['anchors']: List(Tensor) [M, 2]
-            outputs['strides']: List(Int) [8, 16, 32] output stride
-            outputs['stride_tensor']: List(Tensor) [M, 1]
-            targets: (List) [dict{'boxes': [...], 
-                                 'labels': [...], 
-                                 'orig_size': ...}, ...]
-        """
-        # preds: [B, M, C]
-        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
-        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
-        box_preds = torch.cat(outputs['pred_box'], dim=1)
-        bs, num_anchors = cls_preds.shape[:2]
-        device = cls_preds.device
-        anchors = torch.cat(outputs['anchors'], dim=0)
-        
-        # --------------- label assignment ---------------
-        gt_score_targets = []
-        gt_bbox_targets = []
-        fg_masks = []
-        for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
-            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
-
-            if self.cfg.normalize_coords:
-                img_h, img_w = outputs['image_size']
-                tgt_boxs[..., [0, 2]] *= img_w
-                tgt_boxs[..., [1, 3]] *= img_h
-            
-            if self.cfg.box_format == 'xywh':
-                tgt_boxs_x1y1 = tgt_boxs[..., :2] - 0.5 * tgt_boxs[..., 2:]
-                tgt_boxs_x2y2 = tgt_boxs[..., :2] + 0.5 * tgt_boxs[..., 2:]
-                tgt_boxs = torch.cat([tgt_boxs_x1y1, tgt_boxs_x2y2], dim=-1)
-
-            # check target
-            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
-                # There is no valid gt
-                fg_mask  = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
-                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
-                gt_box   = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
-            else:
-                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
-                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
-                (
-                    _,
-                    gt_box,     # [1, M, 4]
-                    gt_score,   # [1, M, C]
-                    fg_mask,    # [1, M,]
-                    _
-                ) = self.matcher(
-                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
-                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
-                    anc_points = anchors,
-                    gt_labels = tgt_labels,
-                    gt_bboxes = tgt_boxs
-                    )
-            gt_score_targets.append(gt_score)
-            gt_bbox_targets.append(gt_box)
-            fg_masks.append(fg_mask)
-
-        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
-        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
-        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
-        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
-        num_fgs = gt_score_targets.sum()
-        
-        # Average loss normalizer across all the GPUs
-        if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_fgs)
-        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
-
-        # ------------------ Classification loss ------------------
-        cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, gt_score_targets)
-        loss_cls = loss_cls.sum() / num_fgs
-
-        # ------------------ Regression loss ------------------
-        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
-        box_targets_pos = gt_bbox_targets.view(-1, 4)[fg_masks]
-        bbox_weight = gt_score_targets[fg_masks].sum(-1)
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
-        loss_box = loss_box.sum() / num_fgs
-
-        # ------------------ Distribution focal loss  ------------------
-        ## process anchors
-        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
-        ## process stride tensors
-        strides = torch.cat(outputs['stride_tensor'], dim=0)
-        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
-        ## fg preds
-        reg_preds_pos = reg_preds.view(-1, 4*self.reg_max)[fg_masks]
-        anchors_pos = anchors[fg_masks]
-        strides_pos = strides[fg_masks]
-        ## compute dfl
-        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
-        loss_dfl = loss_dfl.sum() / num_fgs
-
-        # total loss
-        losses = loss_cls * self.loss_cls_weight + loss_box * self.loss_box_weight + loss_dfl * self.loss_dfl_weight
-        loss_dict = dict(
-                loss_cls = loss_cls,
-                loss_box = loss_box,
-                loss_dfl = loss_dfl,
-                losses = losses
-        )
-
-        return loss_dict
-    
-
-if __name__ == "__main__":
-    pass

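loss_dfl above supervises the per-distance distribution with soft targets: a continuous target distance t (in stride units) is split between its two neighbouring integer bins, floor(t) and floor(t) + 1, weighted by how close t is to each, and each bin contributes a cross-entropy term with that weight. A tiny numeric illustration of the weighting, independent of the repo's bbox2dist helper:

import torch
import torch.nn.functional as F

reg_max = 16
gt_ltrb_s = torch.tensor([2.7])                        # one target distance in stride units

gt_left  = gt_ltrb_s.to(torch.long)                    # bin 2
gt_right = gt_left + 1                                 # bin 3
weight_left  = gt_right.to(torch.float) - gt_ltrb_s    # 0.3
weight_right = 1.0 - weight_left                       # 0.7

pred_logits = torch.randn(1, reg_max)                  # logits over the reg_max bins
loss = (F.cross_entropy(pred_logits, gt_left,  reduction='none') * weight_left +
        F.cross_entropy(pred_logits, gt_right, reduction='none') * weight_right)
print(gt_left.item(), gt_right.item(), weight_left.item(), weight_right.item(), loss.item())
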
+ 0 - 199
yolo/models/gelan/matcher.py

@@ -1,199 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from utils.box_ops import bbox_iou
-
-
-# -------------------------- Task Aligned Assigner --------------------------
-class TaskAlignedAssigner(nn.Module):
-    def __init__(self,
-                 num_classes     = 80,
-                 topk_candidates = 10,
-                 alpha           = 0.5,
-                 beta            = 6.0, 
-                 eps             = 1e-9):
-        super(TaskAlignedAssigner, self).__init__()
-        self.topk_candidates = topk_candidates
-        self.num_classes = num_classes
-        self.bg_idx = num_classes
-        self.alpha = alpha
-        self.beta = beta
-        self.eps = eps
-
-    @torch.no_grad()
-    def forward(self,
-                pd_scores,
-                pd_bboxes,
-                anc_points,
-                gt_labels,
-                gt_bboxes):
-        self.bs = pd_scores.size(0)
-        self.n_max_boxes = gt_bboxes.size(1)
-
-        mask_pos, align_metric, overlaps = self.get_pos_mask(
-            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
-
-        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
-            mask_pos, overlaps, self.n_max_boxes)
-
-        # Assigned target
-        target_labels, target_bboxes, target_scores = self.get_targets(
-            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
-
-        # normalize
-        align_metric *= mask_pos
-        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
-        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
-        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
-        target_scores = target_scores * norm_align_metric
-
-        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
-
-    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
-        # get in_gts mask, (b, max_num_obj, h*w)
-        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
-        # get anchor_align metric, (b, max_num_obj, h*w)
-        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts)
-        # get topk_metric mask, (b, max_num_obj, h*w)
-        mask_topk = self.select_topk_candidates(align_metric)
-        # merge all mask to a final mask, (b, max_num_obj, h*w)
-        mask_pos = mask_topk * mask_in_gts
-
-        return mask_pos, align_metric, overlaps
-
-    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts):
-        """Compute alignment metric given predicted and ground truth bounding boxes."""
-        na = pd_bboxes.shape[-2]
-        mask_in_gts = mask_in_gts.bool()  # b, max_num_obj, h*w
-        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
-        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
-
-        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
-        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
-        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
-        # Get the scores of each grid for each gt cls
-        bbox_scores[mask_in_gts] = pd_scores[ind[0], :, ind[1]][mask_in_gts]  # b, max_num_obj, h*w
-
-        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
-        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_in_gts]
-        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_in_gts]
-        overlaps[mask_in_gts] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
-
-        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
-        return align_metric, overlaps
-
-    def select_topk_candidates(self, metrics, largest=True):
-        """
-        Args:
-            metrics: (b, max_num_obj, h*w).
-            largest: (bool) if True, select the top-k largest metrics.
-        """
-        # (b, max_num_obj, topk)
-        topk_metrics, topk_idxs = torch.topk(metrics, self.topk_candidates, dim=-1, largest=largest)
-        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
-        # (b, max_num_obj, topk)
-        topk_idxs.masked_fill_(~topk_mask, 0)
-
-        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
-        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
-        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
-        for k in range(self.topk_candidates):
-            # Expand topk_idxs for each value of k and add 1 at the specified positions
-            count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
-        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
-        # Filter invalid bboxes
-        count_tensor.masked_fill_(count_tensor > 1, 0)
-
-        return count_tensor.to(metrics.dtype)
-
-    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
-        # Assigned target labels, (b, 1)
-        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
-        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
-        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
-
-        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
-        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
-
-        # Assigned target scores
-        target_labels.clamp_(0)
-
-        # 10x faster than F.one_hot()
-        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
-                                    dtype=torch.int64,
-                                    device=target_labels.device)  # (b, h*w, 80)
-        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
-
-        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
-        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
-
-        return target_labels, target_bboxes, target_scores
-    
-
-# -------------------------- Basic Functions --------------------------
-def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
-    """select the positive anchors's center in gt
-    Args:
-        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
-        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
-    Return:
-        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    n_anchors = xy_centers.size(0)
-    bs, n_max_boxes, _ = gt_bboxes.size()
-    _gt_bboxes = gt_bboxes.reshape([-1, 4])
-    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
-    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
-    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
-    b_lt = xy_centers - gt_bboxes_lt
-    b_rb = gt_bboxes_rb - xy_centers
-    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
-    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
-    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
-
-def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
-    """if an anchor box is assigned to multiple gts,
-        the one with the highest iou will be selected.
-    Args:
-        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    Return:
-        target_gt_idx (Tensor): shape(bs, num_total_anchors)
-        fg_mask (Tensor): shape(bs, num_total_anchors)
-        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    fg_mask = mask_pos.sum(-2)
-    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
-        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
-        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
-
-        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
-        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
-
-        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
-        fg_mask = mask_pos.sum(-2)
-    # Find each grid serve which gt(index)
-    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
-
-    return target_gt_idx, fg_mask, mask_pos
-
-def iou_calculator(box1, box2, eps=1e-9):
-    """Calculate iou for batch
-    Args:
-        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
-        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
-    Return:
-        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
-    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
-    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
-    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
-    x1y1 = torch.maximum(px1y1, gx1y1)
-    x2y2 = torch.minimum(px2y2, gx2y2)
-    overlap = (x2y2 - x1y1).clip(0).prod(-1)
-    area1 = (px2y2 - px1y1).clip(0).prod(-1)
-    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
-    union = area1 + area2 - overlap + eps
-
-    return overlap / union

+ 0 - 0
yolo/models/yolof/build.py


+ 144 - 0
yolo/models/yolof/loss.py

@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.box_ops import *
+from utils.misc import sigmoid_focal_loss
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import UniformMatcher
+
+
+class SetCriterion(nn.Module):
+    """
+        This code is adapted from https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/yolof.py
+    """
+    def __init__(self, cfg):
+        super().__init__()
+        # ------------- Basic parameters -------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        # ------------- Focal loss -------------
+        self.alpha = cfg.focal_loss_alpha
+        self.gamma = cfg.focal_loss_gamma
+        # ------------- Loss weight -------------
+        self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
+                            'loss_reg': cfg.loss_reg_weight}
+        # ------------- Matcher -------------
+        self.matcher_cfg = cfg.matcher_hpy
+        self.matcher = UniformMatcher(self.matcher_cfg['topk_candidates'])
+
+    def loss_labels(self, pred_cls, tgt_cls, num_boxes):
+        """
+            pred_cls: (Tensor) [N, C]
+            tgt_cls:  (Tensor) [N, C]
+        """
+        # cls loss: [V, C]
+        loss_cls = sigmoid_focal_loss(pred_cls, tgt_cls, self.alpha, self.gamma)
+
+        return loss_cls.sum() / num_boxes
+
+    def loss_bboxes(self, pred_box, tgt_box, num_boxes):
+        """
+            pred_box: (Tensor) [N, 4]
+            tgt_box:  (Tensor) [N, 4]
+        """
+        # giou
+        pred_giou = generalized_box_iou(pred_box, tgt_box)  # [N, M]
+        # giou loss
+        loss_reg = 1. - torch.diag(pred_giou)
+
+        return loss_reg.sum() / num_boxes
+
+    def forward(self, outputs, targets):
+        """
+            outputs['pred_cls']: (Tensor) [B, M, C]
+            outputs['pred_box']: (Tensor) [B, M, 4]
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        # -------------------- Pre-process --------------------
+        pred_box = outputs['pred_box']
+        pred_cls = outputs['pred_cls'].reshape(-1, self.num_classes)
+        anchor_boxes = outputs['anchors']
+        masks = ~outputs['mask']
+        device = pred_box.device
+        B = len(targets)
+
+        # -------------------- Label assignment --------------------
+        indices = self.matcher(pred_box, anchor_boxes, targets)
+
+        # [M, 4] -> [1, M, 4] -> [B, M, 4]
+        anchor_boxes = box_cxcywh_to_xyxy(anchor_boxes)
+        anchor_boxes = anchor_boxes[None].repeat(B, 1, 1)
+
+        ious = []
+        pos_ious = []
+        for i in range(B):
+            src_idx, tgt_idx = indices[i]
+            # iou between predbox and tgt box
+            iou, _ = box_iou(pred_box[i, ...], (targets[i]['boxes']).clone())
+            if iou.numel() == 0:
+                max_iou = iou.new_full((iou.size(0),), 0)
+            else:
+                max_iou = iou.max(dim=1)[0]
+            # iou between anchorbox and tgt box
+            a_iou, _ = box_iou(anchor_boxes[i], (targets[i]['boxes']).clone())
+            if a_iou.numel() == 0:
+                pos_iou = a_iou.new_full((0,), 0)
+            else:
+                pos_iou = a_iou[src_idx, tgt_idx]
+            ious.append(max_iou)
+            pos_ious.append(pos_iou)
+
+        ious = torch.cat(ious)
+        ignore_idx = ious > self.matcher_cfg['ignore_thresh']
+        pos_ious = torch.cat(pos_ious)
+        pos_ignore_idx = pos_ious < self.matcher_cfg['iou_thresh']
+
+        src_idx = torch.cat(
+            [src + idx * anchor_boxes[0].shape[0] for idx, (src, _) in
+             enumerate(indices)])
+        # [BM,]
+        gt_cls = torch.full(pred_cls.shape[:1],
+                                self.num_classes,
+                                dtype=torch.int64,
+                                device=device)
+        gt_cls[ignore_idx] = -1
+        tgt_cls_o = torch.cat([t['labels'][J] for t, (_, J) in zip(targets, indices)])
+        tgt_cls_o[pos_ignore_idx] = -1
+
+        gt_cls[src_idx] = tgt_cls_o.to(device)
+
+        foreground_idxs = (gt_cls >= 0) & (gt_cls != self.num_classes)
+        num_foreground = foreground_idxs.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_foreground)
+        num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()
+
+        # -------------------- Classification loss --------------------
+        gt_cls_target = torch.zeros_like(pred_cls)
+        gt_cls_target[foreground_idxs, gt_cls[foreground_idxs]] = 1
+        valid_idxs = (gt_cls >= 0) & masks
+        loss_labels = self.loss_labels(pred_cls[valid_idxs], gt_cls_target[valid_idxs], num_foreground)
+
+        # -------------------- Regression loss --------------------
+        tgt_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(device)
+        tgt_boxes = tgt_boxes[~pos_ignore_idx]
+        matched_pred_box = pred_box.reshape(-1, 4)[src_idx[~pos_ignore_idx.cpu()]]
+        loss_bboxes = self.loss_bboxes(matched_pred_box, tgt_boxes, num_foreground)
+
+        total_loss = loss_labels * self.weight_dict["loss_cls"] + \
+                     loss_bboxes * self.weight_dict["loss_reg"]
+        loss_dict = dict(
+                loss_cls = loss_labels,
+                loss_reg = loss_bboxes,
+                losses   = total_loss,
+        )
+
+        return loss_dict
+
+
+if __name__ == "__main__":
+    pass

+ 103 - 0
yolo/models/yolof/matcher.py

@@ -0,0 +1,103 @@
+import numpy as np
+import torch
+from torch import nn
+from utils.box_ops import *
+
+
+class UniformMatcher(nn.Module):
+    """
+    This code is adapted from https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/uniform_matcher.py
+    """
+    def __init__(self, match_times: int = 4):
+        super().__init__()
+        self.match_times = match_times
+
+    @torch.no_grad()
+    def forward(self, pred_boxes, anchor_boxes, targets):
+        """
+            pred_boxes:   (Tensor) -> [B, num_queries, 4]
+            anchor_boxes: (Tensor) -> [num_queries, 4]
+            targets:      (List[Dict]) -> [{'boxes': [...], 'labels': [...]}, ...]
+        """
+
+        bs, num_queries = pred_boxes.shape[:2]
+
+        # We flatten to compute the cost matrices in a batch
+        # [B, num_queries, 4] -> [M, 4]
+        out_bbox = pred_boxes.flatten(0, 1)
+        # [num_queries, 4] -> [1, num_queries, 4] -> [B, num_queries, 4] -> [M, 4]
+        anchor_boxes = anchor_boxes[None].repeat(bs, 1, 1)
+        anchor_boxes = anchor_boxes.flatten(0, 1)
+
+        # Also concat the target boxes
+        tgt_bbox = torch.cat([v['boxes'] for v in targets])
+
+        # Compute the L1 cost between boxes
+        # Note that we use anchors and predict boxes both
+        cost_bbox = torch.cdist(box_xyxy_to_cxcywh(out_bbox), 
+                                box_xyxy_to_cxcywh(tgt_bbox), 
+                                p=1)
+        cost_bbox_anchors = torch.cdist(anchor_boxes, 
+                                        box_xyxy_to_cxcywh(tgt_bbox), 
+                                        p=1)
+
+        # Final cost matrix: [B, M, N], M=num_queries, N=num_tgt
+        C = cost_bbox
+        C = C.view(bs, num_queries, -1).cpu()
+        C1 = cost_bbox_anchors
+        C1 = C1.view(bs, num_queries, -1).cpu()
+
+        sizes = [len(v['boxes']) for v in targets]  # the number of object instances in each image
+        all_indices_list = [[] for _ in range(bs)]
+        # positive indices when matching predicted boxes and gt boxes
+        # len(indices) = batch size
+        # len(tuple) = match_times (top-k)
+        indices = [
+            tuple(
+                torch.topk(
+                    c[i],
+                    k=self.match_times,
+                    dim=0,
+                    largest=False)[1].numpy().tolist()
+            )
+            for i, c in enumerate(C.split(sizes, -1))
+        ]
+        # positive indices when matching anchor boxes and gt boxes
+        indices1 = [
+            tuple(
+                torch.topk(
+                    c[i],
+                    k=self.match_times,
+                    dim=0,
+                    largest=False)[1].numpy().tolist())
+            for i, c in enumerate(C1.split(sizes, -1))]
+
+        # concat the indices according to image ids
+        # img_id = batch_id
+        for img_id, (idx, idx1) in enumerate(zip(indices, indices1)):
+            img_idx_i = [
+                np.array(idx_ + idx1_)
+                for (idx_, idx1_) in zip(idx, idx1)
+            ] # 'i' is the index of queries
+            img_idx_j = [
+                np.array(list(range(len(idx_))) + list(range(len(idx1_))))
+                for (idx_, idx1_) in zip(idx, idx1)
+            ] # 'j' is the index of tgt
+            all_indices_list[img_id] = [*zip(img_idx_i, img_idx_j)]
+
+        # re-organize the positive indices
+        all_indices = []
+        for img_id in range(bs):
+            all_idx_i = []
+            all_idx_j = []
+            for idx_list in all_indices_list[img_id]:
+                idx_i, idx_j = idx_list
+                all_idx_i.append(idx_i)
+                all_idx_j.append(idx_j)
+            all_idx_i = np.hstack(all_idx_i)
+            all_idx_j = np.hstack(all_idx_j)
+            all_indices.append((all_idx_i, all_idx_j))
+
+
+        return [(torch.as_tensor(i, dtype=torch.int64), 
+                 torch.as_tensor(j, dtype=torch.int64)) for i, j in all_indices]
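
A minimal usage sketch of the `UniformMatcher` above, using random tensors purely for illustration; the import path, the tensor shapes, and the availability of the repository's `utils.box_ops` helpers on the Python path are assumptions, not part of this commit:

```python
import torch
from models.yolof.matcher import UniformMatcher  # import path assumed

matcher = UniformMatcher(match_times=4)
pred_boxes   = torch.rand(2, 100, 4)   # [B, num_queries, 4], xyxy format
anchor_boxes = torch.rand(100, 4)      # [num_queries, 4], cxcywh format
targets = [{'boxes': torch.rand(3, 4), 'labels': torch.randint(0, 80, (3,))},
           {'boxes': torch.rand(5, 4), 'labels': torch.randint(0, 80, (5,))}]

indices = matcher(pred_boxes, anchor_boxes, targets)
# One (query_idx, gt_idx) pair of index tensors per image: every gt box collects
# match_times candidates from the predicted boxes and match_times from the anchors,
# so each index tensor holds 2 * match_times * num_gt entries.
print([(i.shape, j.shape) for i, j in indices])
```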

+ 148 - 0
yolo/models/yolof/modules.py

@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        use_bias = False if norm_type is not None else True
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1, bias=use_bias)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# --------------------- ResNet modules ---------------------
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
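
As a small illustrative sketch (not part of the commit), the `BasicConv` wrapper above can be exercised standalone in either its standard or depthwise form; the import path is assumed:

```python
import torch
from models.yolof.modules import BasicConv  # import path assumed

x = torch.randn(1, 64, 32, 32)
conv   = BasicConv(64, 128, kernel_size=3, padding=1, stride=1)                  # Conv -> BN -> LeakyReLU
dwconv = BasicConv(64, 128, kernel_size=3, padding=1, stride=1, depthwise=True)  # 3x3 depthwise + 1x1 pointwise
print(conv(x).shape, dwconv(x).shape)  # both: torch.Size([1, 128, 32, 32])
```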

+ 187 - 0
yolo/models/yolof/resnet.py

@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+
+try:
+    from .modules import conv1x1, BasicBlock, Bottleneck
+except:
+    from  modules import conv1x1, BasicBlock, Bottleneck
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+# --------------------- ResNet -----------------------
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, zero_init_residual=False):
+        super(ResNet, self).__init__()
+        self.inplanes = 64
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        """
+        Input:
+            x: (Tensor) -> [B, C, H, W]
+        Output:
+            c5: (Tensor) -> [B, C, H/32, W/32]
+        """
+        c1 = self.conv1(x)     # [B, C, H/2, W/2]
+        c1 = self.bn1(c1)      # [B, C, H/2, W/2]
+        c1 = self.relu(c1)     # [B, C, H/2, W/2]
+        c2 = self.maxpool(c1)  # [B, C, H/4, W/4]
+
+        c2 = self.layer1(c2)   # [B, C, H/4, W/4]
+        c3 = self.layer2(c2)   # [B, C, H/8, W/8]
+        c4 = self.layer3(c3)   # [B, C, H/16, W/16]
+        c5 = self.layer4(c4)   # [B, C, H/32, W/32]
+
+        return c5
+
+
+# --------------------- Functions -----------------------
+def build_resnet(model_name="resnet18", pretrained=False):
+    if model_name == 'resnet18':
+        model = resnet18(pretrained)
+        feat_dim = 512
+    elif model_name == 'resnet34':
+        model = resnet34(pretrained)
+        feat_dim = 512
+    elif model_name == 'resnet50':
+        model = resnet50(pretrained)
+        feat_dim = 2048
+    elif model_name == 'resnet101':
+        model = resnet101(pretrained)
+        feat_dim = 2048
+    else:
+        raise NotImplementedError("Unknown resnet: {}".format(model_name))
+    
+    return model, feat_dim
+
+def resnet18(pretrained=False, **kwargs):
+    """Constructs a ResNet-18 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+    if pretrained:
+        # strict = False as we don't need fc layer params.
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
+    return model
+
+def resnet34(pretrained=False, **kwargs):
+    """Constructs a ResNet-34 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
+    return model
+
+def resnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
+    return model
+
+def resnet101(pretrained=False, **kwargs):
+    """Constructs a ResNet-101 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
+    return model
+
+def resnet152(pretrained=False, **kwargs):
+    """Constructs a ResNet-152 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
+    return model
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+
+    # Build backbone
+    model, _ = build_resnet(model_name='resnet18')
+
+    # Inference
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    output = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print(output.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))    

+ 0 - 0
yolo/models/yolof/yolof.py


+ 0 - 0
yolo/models/yolof/yolof_backbone.py


+ 185 - 0
yolo/models/yolof/yolof_decoder.py

@@ -0,0 +1,185 @@
+import math
+import torch
+import torch.nn as nn
+
+from .modules import BasicConv
+
+
+class YolofHead(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim,):
+        super().__init__()
+        self.fmp_size = None
+        self.ctr_clamp = cfg.center_clamp
+        self.DEFAULT_EXP_CLAMP = math.log(1e8)
+        self.DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)
+        # ------------------ Basic parameters -------------------
+        self.cfg = cfg
+        self.in_dim = in_dim
+        self.stride       = cfg.out_stride
+        self.num_classes  = cfg.num_classes
+        self.num_cls_head = cfg.num_cls_head
+        self.num_reg_head = cfg.num_reg_head
+        self.act_type     = cfg.head_act
+        self.norm_type    = cfg.head_norm
+        # Anchor config
+        self.anchor_size = torch.as_tensor(cfg.anchor_size)
+        self.num_anchors = len(cfg.anchor_size)
+
+        # ------------------ Network parameters -------------------
+        ## cls head
+        cls_heads = []
+        self.cls_head_dim = out_dim
+        for i in range(self.num_cls_head):
+            if i == 0:
+                cls_heads.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+            else:
+                cls_heads.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+        ## reg head
+        reg_heads = []
+        self.reg_head_dim = out_dim
+        for i in range(self.num_reg_head):
+            if i == 0:
+                reg_heads.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+            else:
+                reg_heads.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+        self.cls_heads = nn.Sequential(*cls_heads)
+        self.reg_heads = nn.Sequential(*reg_heads)
+
+        # pred layer
+        self.obj_pred = nn.Conv2d(self.reg_head_dim, 1 * self.num_anchors, kernel_size=3, padding=1)
+        self.cls_pred = nn.Conv2d(self.cls_head_dim, self.num_classes * self.num_anchors, kernel_size=3, padding=1)
+        self.reg_pred = nn.Conv2d(self.reg_head_dim, 4 * self.num_anchors, kernel_size=3, padding=1)
+
+        # init bias
+        self._init_pred_layers()
+
+    def _init_pred_layers(self):  
+        # init cls pred
+        nn.init.normal_(self.cls_pred.weight, mean=0, std=0.01)
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        nn.init.constant_(self.cls_pred.bias, bias_value)
+        # init reg pred
+        nn.init.normal_(self.reg_pred.weight, mean=0, std=0.01)
+        nn.init.constant_(self.reg_pred.bias, 0.0)
+        # init obj pred
+        nn.init.normal_(self.obj_pred.weight, mean=0, std=0.01)
+        nn.init.constant_(self.obj_pred.bias, 0.0)
+
+    def get_anchors(self, fmp_size):
+        """fmp_size: list -> [H, W] \n
+           stride: int -> output stride
+        """
+        # check anchor boxes
+        if self.fmp_size is not None and self.fmp_size == fmp_size:
+            return self.anchor_boxes
+        else:
+            # generate grid cells
+            fmp_h, fmp_w = fmp_size
+            anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+            # [H, W, 2] -> [HW, 2]
+            anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
+            # [HW, 2] -> [HW, 1, 2] -> [HW, KA, 2] 
+            anchor_xy = anchor_xy[:, None, :].repeat(1, self.num_anchors, 1)
+            anchor_xy *= self.stride
+
+            # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2]
+            anchor_wh = self.anchor_size[None, :, :].repeat(fmp_h*fmp_w, 1, 1)
+
+            # [HW, KA, 4] -> [M, 4]
+            anchor_boxes = torch.cat([anchor_xy, anchor_wh], dim=-1)
+            anchor_boxes = anchor_boxes.view(-1, 4)
+
+            self.anchor_boxes = anchor_boxes
+            self.fmp_size = fmp_size
+
+            return anchor_boxes
+        
+    def decode_boxes(self, anchor_boxes, pred_reg):
+        """
+            anchor_boxes: (List[tensor]) [1, M, 4]
+            pred_reg: (List[tensor]) [B, M, 4]
+        """
+        # x = x_anchor + dx * w_anchor
+        # y = y_anchor + dy * h_anchor
+        pred_ctr_offset = pred_reg[..., :2] * anchor_boxes[..., 2:]
+        pred_ctr_offset = torch.clamp(pred_ctr_offset, min=-self.ctr_clamp, max=self.ctr_clamp)
+        pred_ctr_xy = anchor_boxes[..., :2] + pred_ctr_offset
+
+        # w = w_anchor * exp(tw)
+        # h = h_anchor * exp(th)
+        pred_dwdh = pred_reg[..., 2:]
+        pred_dwdh = torch.clamp(pred_dwdh, max=self.DEFAULT_SCALE_CLAMP)
+        pred_wh = anchor_boxes[..., 2:] * pred_dwdh.exp()
+
+        # convert [x, y, w, h] -> [x1, y1, x2, y2]
+        pred_x1y1 = pred_ctr_xy - 0.5 * pred_wh
+        pred_x2y2 = pred_ctr_xy + 0.5 * pred_wh
+        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return pred_box
+
+    def forward(self, x, mask=None):
+        # ------------------- Decoupled head -------------------
+        cls_feats = self.cls_heads(x)
+        reg_feats = self.reg_heads(x)
+
+        # ------------------- Generate anchor box -------------------
+        fmp_size = cls_feats.shape[2:]
+        anchor_boxes = self.get_anchors(fmp_size)   # [M, 4]
+        anchor_boxes = anchor_boxes.to(cls_feats.device)
+
+        # ------------------- Predict -------------------
+        obj_pred = self.obj_pred(reg_feats)
+        cls_pred = self.cls_pred(cls_feats)
+        reg_pred = self.reg_pred(reg_feats)
+
+        # ------------------- Process preds -------------------
+        ## implicit objectness
+        B, _, H, W = obj_pred.size()
+        obj_pred = obj_pred.view(B, -1, 1, H, W)
+        cls_pred = cls_pred.view(B, -1, self.num_classes, H, W)
+        normalized_cls_pred = cls_pred + obj_pred - torch.log(
+                1. + 
+                torch.clamp(cls_pred, max=self.DEFAULT_EXP_CLAMP).exp() + 
+                torch.clamp(obj_pred, max=self.DEFAULT_EXP_CLAMP).exp())
+        # [B, KA, C, H, W] -> [B, H, W, KA, C] -> [B, M, C], M = HxWxKA
+        normalized_cls_pred = normalized_cls_pred.permute(0, 3, 4, 1, 2).contiguous()
+        normalized_cls_pred = normalized_cls_pred.view(B, -1, self.num_classes)
+        # [B, KA*4, H, W] -> [B, KA, 4, H, W] -> [B, H, W, KA, 4] -> [B, M, 4]
+        reg_pred = reg_pred.view(B, -1, 4, H, W).permute(0, 3, 4, 1, 2).contiguous()
+        reg_pred = reg_pred.view(B, -1, 4)
+        ## Decode bbox
+        box_pred = self.decode_boxes(anchor_boxes[None], reg_pred)  # [B, M, 4]
+        ## adjust mask
+        if mask is not None:
+            # [B, H, W]
+            mask = torch.nn.functional.interpolate(mask[None].float(), size=fmp_size).bool()[0]
+            # [B, H, W] -> [B, HW]
+            mask = mask.flatten(1)
+            # [B, HW] -> [B, HW, KA] -> [BM,], M= HW x KA
+            mask = mask[..., None].repeat(1, 1, self.num_anchors).flatten()
+
+        outputs = {"pred_cls": normalized_cls_pred,
+                   "pred_reg": reg_pred,
+                   "pred_box": box_pred,
+                   "anchors": anchor_boxes,
+                   "mask": mask}
+
+        return outputs 
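
Below is a small, self-contained sketch (not repository code) of the anchor-box decoding performed by `decode_boxes` above: the predicted center offset is scaled by the anchor size and clamped, the width/height deltas are exponentiated under a scale clamp, and the result is converted to xyxy. The `ctr_clamp=32.0` default here is only illustrative; in the decoder it comes from `cfg.center_clamp`.

```python
import math
import torch

def decode_boxes_sketch(anchors_cxcywh, deltas, ctr_clamp=32.0,
                        scale_clamp=math.log(1000.0 / 16)):
    """anchors_cxcywh: [M, 4] (cx, cy, w, h); deltas: [M, 4] (dx, dy, dw, dh)."""
    ctr_offset = deltas[..., :2] * anchors_cxcywh[..., 2:]        # dx * w_a, dy * h_a
    ctr_offset = ctr_offset.clamp(min=-ctr_clamp, max=ctr_clamp)  # limit the center shift (pixels)
    ctr_xy = anchors_cxcywh[..., :2] + ctr_offset
    wh = anchors_cxcywh[..., 2:] * deltas[..., 2:].clamp(max=scale_clamp).exp()
    return torch.cat([ctr_xy - 0.5 * wh, ctr_xy + 0.5 * wh], dim=-1)  # (x1, y1, x2, y2)

anchors = torch.tensor([[64.0, 64.0, 32.0, 32.0]])
deltas  = torch.tensor([[0.25, -0.50, 0.0, math.log(2.0)]])
print(decode_boxes_sketch(anchors, deltas))
# tensor([[56., 16., 88., 80.]])  -> center moved to (72, 48), size 32 x 64
```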

+ 72 - 0
yolo/models/yolof/yolof_encoder.py

@@ -0,0 +1,72 @@
+import torch.nn as nn
+from utils import weight_init
+
+from .modules import BasicConv
+
+
+# BottleNeck
+class Bottleneck(nn.Module):
+    def __init__(self, in_dim, dilation, expand_ratio, act_type='relu', norm_type='BN'):
+        super(Bottleneck, self).__init__()
+        # ------------------ Basic parameters -------------------
+        self.in_dim = in_dim
+        self.dilation = dilation
+        self.expand_ratio = expand_ratio
+        inter_dim = round(in_dim * expand_ratio)
+        # ------------------ Network parameters -------------------
+        self.branch = nn.Sequential(
+            BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type),
+            BasicConv(inter_dim, inter_dim, kernel_size=3, padding=dilation, dilation=dilation, act_type=act_type, norm_type=norm_type),
+            BasicConv(inter_dim, in_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        )
+
+    def forward(self, x):
+        return x + self.branch(x)
+
+# Dilated Encoder
+class DilatedEncoder(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim):
+        super(DilatedEncoder, self).__init__()
+        # ------------------ Basic parameters -------------------
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.expand_ratio = cfg.neck_expand_ratio
+        self.dilations    = cfg.neck_dilations
+        self.act_type     = cfg.neck_act
+        self.norm_type    = cfg.neck_norm
+        # ------------------ Network parameters -------------------
+        ## proj layer
+        self.projector = nn.Sequential(
+            BasicConv(in_dim, out_dim, kernel_size=1, act_type=None, norm_type=self.norm_type),
+            BasicConv(out_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=self.norm_type)
+        )
+        ## encoder layers
+        self.encoders = nn.Sequential(
+            *[Bottleneck(out_dim, d, self.expand_ratio, self.act_type, self.norm_type) for d in self.dilations])
+
+        self._init_weight()
+
+    def _init_weight(self):
+        for m in self.projector:
+            if isinstance(m, nn.Conv2d):
+                weight_init.c2_xavier_fill(m)
+            if isinstance(m, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        for m in self.encoders.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight, mean=0, std=0.01)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+            if isinstance(m, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.projector(x)
+        x = self.encoders(x)
+
+        return x
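
A quick standalone check (illustrative only, not repository code) of why the 3x3 convolution inside each `Bottleneck` above uses `padding=dilation`: that choice keeps the spatial resolution unchanged for every dilation rate, so the residual addition in `forward` stays shape-compatible. The dilation values below are just an example schedule.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 128, 20, 20)
for d in (2, 4, 6, 8):  # example dilation schedule for the stacked encoder blocks
    y = nn.Conv2d(128, 128, kernel_size=3, padding=d, dilation=d)(x)
    print(d, tuple(y.shape))  # spatial size stays 20x20 for every dilation rate
```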

+ 56 - 0
yolo/models/yolov10/README.md

@@ -0,0 +1,56 @@
+# YOLOv7:
+
+|    Model    |   Backbone    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
+|-------------|---------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
+| YOLOv7-Tiny | ELANNet-Tiny  | 8xb16 |  640  |         39.5           |       58.5        |   22.6            |   7.9              | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov7_tiny_coco.pth) |
+| YOLOv7      | ELANNet-Large | 8xb16 |  640  |         49.5           |       68.8        |   144.6           |   44.0             | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov7_coco.pth) |
+| YOLOv7-X    | ELANNet-Huge  |       |  640  |                        |                   |                   |                    |  |
+
+- For training, we train `YOLOv7` and `YOLOv7-Tiny` for 300 epochs on 8 GPUs.
+- For data augmentation, we use the [YOLOX-style](https://github.com/Megvii-BaseDetection/YOLOX) augmentation, including large-scale jitter (LSJ), Mosaic augmentation and Mixup augmentation.
+- For the optimizer, we use `AdamW` with weight decay 0.05 and a per-image base learning rate of 0.001 / 64, scaled linearly with the global batch size.
+- For the learning rate scheduler, we use a cosine decay scheduler.
+- For YOLOv7's structure, we replace the coupled head with the YOLOX-style decoupled head.
+- In my view, YOLOv7 relies on too many training tricks, such as `anchor box`, `AuxiliaryHead`, `RepConv`, `Mosaic9x` and so on, which makes the YOLO picture overly complicated and runs against the design philosophy of the YOLO series; otherwise, why not just use the DETR family with some acceleration optimizations on top? I therefore stayed faithful to my own technical taste and implemented a cleaner and simpler YOLOv7, but without the boost from all those tricks I did not reproduce the full performance, which is a pity.
+- I have no more GPUs to train my `YOLOv7-X`.
+
+## Train YOLOv7
+### Single GPU
+Taking training YOLOv7-Tiny on COCO as the example,
+```Shell
+python train.py --cuda -d coco --root path/to/coco -m yolov7_tiny -bs 16 -size 640 --wp_epoch 3 --max_epoch 300 --eval_epoch 10 --no_aug_epoch 20 --ema --fp16 --multi_scale 
+```
+
+### Multi GPU
+Taking training YOLOv7-Tiny on COCO as the example,
+```Shell
+python -m torch.distributed.run --nproc_per_node=8 train.py --cuda -dist -d coco --root /data/datasets/ -m yolov7_tiny -bs 128 -size 640 --wp_epoch 3 --max_epoch 300  --eval_epoch 10 --no_aug_epoch 20 --ema --fp16 --sybn --multi_scale --save_folder weights/ 
+```
+
+## Test YOLOv7
+Taking testing YOLOv7-Tiny on COCO-val as the example,
+```Shell
+python test.py --cuda -d coco --root path/to/coco -m yolov7_tiny --weight path/to/yolov7_tiny.pth -size 640 -vt 0.4 --show 
+```
+
+## Evaluate YOLOv7
+Taking evaluating YOLOv7-Tiny on COCO-val as the example,
+```Shell
+python eval.py --cuda -d coco-val --root path/to/coco -m yolov7_tiny --weight path/to/yolov7_tiny.pth 
+```
+
+## Demo
+### Detect with Image
+```Shell
+python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolov7_tiny --weight path/to/weight -size 640 -vt 0.4 --show
+```
+
+### Detect with Video
+```Shell
+python demo.py --mode video --path_to_vid path/to/video --cuda -m yolov7_tiny --weight path/to/weight -size 640 -vt 0.4 --show --gif
+```
+
+### Detect with Camera
+```Shell
+python demo.py --mode camera --cuda -m yolov7_tiny --weight path/to/weight -size 640 -vt 0.4 --show --gif
+```
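
The per-image learning-rate note above can be made concrete with a tiny calculation, assuming the usual linear-scaling convention; the variable names are hypothetical:

```python
base_lr_per_image = 0.001 / 64      # per-image base learning rate from the notes above
global_batch_size = 8 * 16          # the 8xb16 setting: 8 GPUs x 16 images per GPU
lr = base_lr_per_image * global_batch_size
print(lr)                           # 0.002
```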

+ 66 - 0
yolo/models/yolov10/build.py

@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+import torch.nn as nn
+
+from .loss import build_criterion
+from .yolov10 import YOLOv7
+
+
+# build object detector
+def build_yolov7(args, cfg, device, num_classes=80, trainable=False, deploy=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+    
+    print('==============================')
+    print('Model Configuration: \n', cfg)
+    
+    # -------------- Build YOLO --------------
+    model = YOLOv7(cfg                = cfg,
+                   device             = device, 
+                   num_classes        = num_classes,
+                   trainable          = trainable,
+                   conf_thresh        = args.conf_thresh,
+                   nms_thresh         = args.nms_thresh,
+                   topk               = args.topk,
+                   deploy             = deploy,
+                   no_multi_labels    = args.no_multi_labels,
+                   nms_class_agnostic = args.nms_class_agnostic
+                   )
+
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+    # Init bias
+    init_prob = 0.01
+    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+    # obj pred
+    for obj_pred in model.obj_preds:
+        b = obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # cls pred
+    for cls_pred in model.cls_preds:
+        b = cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # reg pred
+    for reg_pred in model.reg_preds:
+        b = reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = reg_pred.weight
+        w.data.fill_(0.)
+        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    # -------------- Build criterion --------------
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(args, cfg, device, num_classes)
+
+    return model, criterion
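
As a quick illustrative check (not part of the commit) of the prior-probability bias initialization used above: with `init_prob = 0.01`, the bias is -log((1 - p) / p) ≈ -4.595, so every objectness/class logit starts with a sigmoid output of roughly 0.01, which keeps the early classification loss from being dominated by the background.

```python
import math

p = 0.01
bias = -math.log((1.0 - p) / p)
print(round(bias, 4))                           # -4.5951
print(round(1.0 / (1.0 + math.exp(-bias)), 4))  # 0.01
```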

+ 212 - 0
yolo/models/yolov10/loss.py

@@ -0,0 +1,212 @@
+import torch
+import torch.nn.functional as F
+from .matcher import SimOTA
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+
+
+class Criterion(object):
+    def __init__(self,
+                 args,
+                 cfg, 
+                 device, 
+                 num_classes=80):
+        self.args = args
+        self.cfg = cfg
+        self.device = device
+        self.num_classes = num_classes
+        self.max_epoch = args.max_epoch
+        self.no_aug_epoch = args.no_aug_epoch
+        self.aux_bbox_loss = False
+        # loss weight
+        self.loss_obj_weight = cfg['loss_obj_weight']
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_box_weight = cfg['loss_box_weight']
+        # matcher
+        matcher_config = cfg['matcher']
+        self.matcher = SimOTA(
+            num_classes=num_classes,
+            center_sampling_radius=matcher_config['center_sampling_radius'],
+            topk_candidate=matcher_config['topk_candicate']
+            )
+
+
+    def loss_objectness(self, pred_obj, gt_obj):
+        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
+
+        return loss_obj
+    
+
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+
+        return loss_cls
+
+
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box, gt_box, "xyxy", 'giou')
+        loss_box = 1.0 - ious
+
+        return loss_box
+
+
+    def loss_bboxes_aux(self, pred_reg, gt_box, anchors, stride_tensors):
+        # xyxy -> cxcy&bwbh
+        gt_cxcy = (gt_box[..., :2] + gt_box[..., 2:]) * 0.5
+        gt_bwbh = gt_box[..., 2:] - gt_box[..., :2]
+        # encode gt box
+        gt_cxcy_encode = (gt_cxcy - anchors) / stride_tensors
+        gt_bwbh_encode = torch.log(gt_bwbh / stride_tensors)
+        gt_box_encode = torch.cat([gt_cxcy_encode, gt_bwbh_encode], dim=-1)
+        # l1 loss
+        loss_box_aux = F.l1_loss(pred_reg, gt_box_encode, reduction='none')
+
+        return loss_box_aux
+
+
+    def __call__(self, outputs, targets, epoch=0):        
+        """
+            outputs['pred_obj']: List(Tensor) [B, M, 1]
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['pred_reg']: List(Tensor) [B, M, 4]
+            outputs['strides']: List(Int) [8, 16, 32] output stride
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        bs = outputs['pred_cls'][0].shape[0]
+        device = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        anchors = outputs['anchors']
+        # preds: [B, M, C]
+        obj_preds = torch.cat(outputs['pred_obj'], dim=1)
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+
+        # label assignment
+        cls_targets = []
+        box_targets = []
+        obj_targets = []
+        fg_masks = []
+
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
+                num_anchors = sum([ab.shape[0] for ab in anchors])
+                # There is no valid gt
+                cls_target = obj_preds.new_zeros((0, self.num_classes))
+                box_target = obj_preds.new_zeros((0, 4))
+                obj_target = obj_preds.new_zeros((num_anchors, 1))
+                fg_mask = obj_preds.new_zeros(num_anchors).bool()
+            else:
+                (
+                    fg_mask,
+                    assigned_labels,
+                    assigned_ious,
+                    assigned_indexs
+                ) = self.matcher(
+                    fpn_strides = fpn_strides,
+                    anchors = anchors,
+                    pred_obj = obj_preds[batch_idx],
+                    pred_cls = cls_preds[batch_idx], 
+                    pred_box = box_preds[batch_idx],
+                    tgt_labels = tgt_labels,
+                    tgt_bboxes = tgt_bboxes
+                    )
+
+                obj_target = fg_mask.unsqueeze(-1)
+                cls_target = F.one_hot(assigned_labels.long(), self.num_classes)
+                cls_target = cls_target * assigned_ious.unsqueeze(-1)
+                box_target = tgt_bboxes[assigned_indexs]
+
+            cls_targets.append(cls_target)
+            box_targets.append(box_target)
+            obj_targets.append(obj_target)
+            fg_masks.append(fg_mask)
+
+        cls_targets = torch.cat(cls_targets, 0)
+        box_targets = torch.cat(box_targets, 0)
+        obj_targets = torch.cat(obj_targets, 0)
+        fg_masks = torch.cat(fg_masks, 0)
+        num_fgs = fg_masks.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # ------------------ Objectness loss ------------------
+        loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
+        loss_obj = loss_obj.sum() / num_fgs
+        
+        # ------------------ Classification loss ------------------
+        cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
+        loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # ------------------ Regression loss ------------------
+        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
+        loss_box = loss_box.sum() / num_fgs
+
+        # total loss
+        losses = self.loss_obj_weight * loss_obj + \
+                 self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box
+
+        # ------------------ Aux regression loss ------------------
+        loss_box_aux = None
+        if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
+            ## reg_preds
+            reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+            reg_preds_pos = reg_preds.view(-1, 4)[fg_masks]
+            ## anchor tensors
+            anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
+            anchors_tensors_pos = anchors_tensors.view(-1, 2)[fg_masks]
+            ## stride tensors
+            stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
+            stride_tensors_pos = stride_tensors.view(-1, 1)[fg_masks]
+            ## aux loss
+            loss_box_aux = self.loss_bboxes_aux(reg_preds_pos, box_targets, anchors_tensors_pos, stride_tensors_pos)
+            loss_box_aux = loss_box_aux.sum() / num_fgs
+
+            losses += loss_box_aux
+
+        # Loss dict
+        if loss_box_aux is None:
+            loss_dict = dict(
+                    loss_obj = loss_obj,
+                    loss_cls = loss_cls,
+                    loss_box = loss_box,
+                    losses = losses
+            )
+        else:
+            loss_dict = dict(
+                    loss_obj = loss_obj,
+                    loss_cls = loss_cls,
+                    loss_box = loss_box,
+                    loss_box_aux = loss_box_aux,
+                    losses = losses
+                    )
+
+        return loss_dict
+    
+
+def build_criterion(args, cfg, device, num_classes):
+    criterion = Criterion(
+        args=args,
+        cfg=cfg,
+        device=device,
+        num_classes=num_classes
+        )
+
+    return criterion
+
+
+if __name__ == "__main__":
+    pass
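
The auxiliary L1 target encoding in `loss_bboxes_aux` above can be illustrated with a self-contained sketch (not repository code): a gt box in xyxy form is converted to center/size form, the center becomes an offset from the anchor point measured in stride units, and the size becomes a log-ratio to the stride. The numbers below are made up for illustration.

```python
import torch

def encode_gt_box(gt_xyxy, anchor_xy, stride):
    """Mirrors the encoding above: returns (dcx, dcy, log(w/s), log(h/s))."""
    cxcy = (gt_xyxy[..., :2] + gt_xyxy[..., 2:]) * 0.5
    bwbh = gt_xyxy[..., 2:] - gt_xyxy[..., :2]
    return torch.cat([(cxcy - anchor_xy) / stride, torch.log(bwbh / stride)], dim=-1)

gt     = torch.tensor([[40.0, 24.0, 104.0, 88.0]])  # a 64x64 box centered at (72, 56)
anchor = torch.tensor([[68.0, 52.0]])               # anchor point (grid center, in pixels)
print(encode_gt_box(gt, anchor, stride=8.0))
# tensor([[0.5000, 0.5000, 2.0794, 2.0794]])  -> half-cell offsets, log(64 / 8) ≈ 2.08
```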

+ 187 - 0
yolo/models/yolov10/matcher.py

@@ -0,0 +1,187 @@
+# ---------------------------------------------------------------------
+# Copyright (c) Megvii Inc. All rights reserved.
+# ---------------------------------------------------------------------
+
+
+import torch
+import torch.nn.functional as F
+from utils.box_ops import *
+
+
+class SimOTA(object):
+    """
+        This code is adapted from https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
+    """
+    def __init__(self, num_classes, center_sampling_radius, topk_candidate ):
+        self.num_classes = num_classes
+        self.center_sampling_radius = center_sampling_radius
+        self.topk_candidate = topk_candidate
+
+
+    @torch.no_grad()
+    def __call__(self, 
+                 fpn_strides, 
+                 anchors, 
+                 pred_obj, 
+                 pred_cls, 
+                 pred_box, 
+                 tgt_labels,
+                 tgt_bboxes):
+        # [M,]
+        strides_tensor = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
+        # List[F, M, 2] -> [M, 2]
+        anchors = torch.cat(anchors, dim=0)
+        num_anchor = anchors.shape[0]        
+        num_gt = len(tgt_labels)
+
+        # ----------------------- Find inside points -----------------------
+        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+            tgt_bboxes, anchors, strides_tensor, num_anchor, num_gt)
+        obj_preds = pred_obj[fg_mask].float()   # [Mp, 1]
+        cls_preds = pred_cls[fg_mask].float()   # [Mp, C]
+        box_preds = pred_box[fg_mask].float()   # [Mp, 4]
+
+        # ----------------------- Reg cost -----------------------
+        pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds)      # [N, Mp]
+        reg_cost = -torch.log(pair_wise_ious + 1e-8)            # [N, Mp]
+
+        # ----------------------- Cls cost -----------------------
+        with torch.cuda.amp.autocast(enabled=False):
+            # [Mp, C]
+            score_preds = torch.sqrt(obj_preds.sigmoid_()* cls_preds.sigmoid_())
+            # [N, Mp, C]
+            score_preds = score_preds.unsqueeze(0).repeat(num_gt, 1, 1)
+            # prepare cls_target
+            cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
+            cls_targets = cls_targets.unsqueeze(1).repeat(1, score_preds.size(1), 1)
+            # [N, Mp]
+            cls_cost = F.binary_cross_entropy(score_preds, cls_targets, reduction="none").sum(-1)
+        del score_preds
+
+        #----------------------- Dynamic K-Matching -----------------------
+        cost_matrix = (
+            cls_cost
+            + 3.0 * reg_cost
+            + 100000.0 * (~is_in_boxes_and_center)
+        ) # [N, Mp]
+
+        (
+            assigned_labels,         # [num_fg,]
+            assigned_ious,           # [num_fg,]
+            assigned_indexs,         # [num_fg,]
+        ) = self.dynamic_k_matching(
+            cost_matrix,
+            pair_wise_ious,
+            tgt_labels,
+            num_gt,
+            fg_mask
+            )
+        del cls_cost, cost_matrix, pair_wise_ious, reg_cost
+
+        return fg_mask, assigned_labels, assigned_ious, assigned_indexs
+
+
+    def get_in_boxes_info(
+        self,
+        gt_bboxes,   # [N, 4]
+        anchors,     # [M, 2]
+        strides,     # [M,]
+        num_anchors, # M
+        num_gt,      # N
+        ):
+        # anchor center
+        x_centers = anchors[:, 0]
+        y_centers = anchors[:, 1]
+
+        # [M,] -> [1, M] -> [N, M]
+        x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
+        y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
+
+        # [N,] -> [N, 1] -> [N, M]
+        gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
+        gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
+        gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
+        gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
+
+        b_l = x_centers - gt_bboxes_l
+        b_r = gt_bboxes_r - x_centers
+        b_t = y_centers - gt_bboxes_t
+        b_b = gt_bboxes_b - y_centers
+        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
+
+        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
+        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
+        # in fixed center
+        center_radius = self.center_sampling_radius
+
+        # [N, 2]
+        gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
+        
+        # [1, M]
+        center_radius_ = center_radius * strides.unsqueeze(0)
+
+        gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
+        gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
+        gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
+        gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
+
+        c_l = x_centers - gt_bboxes_l
+        c_r = gt_bboxes_r - x_centers
+        c_t = y_centers - gt_bboxes_t
+        c_b = gt_bboxes_b - y_centers
+        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
+        is_in_centers = center_deltas.min(dim=-1).values > 0.0
+        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+
+        # in boxes and in centers
+        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
+
+        is_in_boxes_and_center = (
+            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
+        )
+        return is_in_boxes_anchor, is_in_boxes_and_center
+    
+    
+    def dynamic_k_matching(
+        self, 
+        cost, 
+        pair_wise_ious, 
+        gt_classes, 
+        num_gt, 
+        fg_mask
+        ):
+        # Dynamic K
+        # ---------------------------------------------------------------
+        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+
+        ious_in_boxes_matrix = pair_wise_ious
+        n_candidate_k = min(self.topk_candidate, ious_in_boxes_matrix.size(1))
+        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+        dynamic_ks = dynamic_ks.tolist()
+        for gt_idx in range(num_gt):
+            _, pos_idx = torch.topk(
+                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
+            )
+            matching_matrix[gt_idx][pos_idx] = 1
+
+        del topk_ious, dynamic_ks, pos_idx
+
+        anchor_matching_gt = matching_matrix.sum(0)
+        if (anchor_matching_gt > 1).sum() > 0:
+            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
+            matching_matrix[:, anchor_matching_gt > 1] *= 0
+            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+
+        fg_mask[fg_mask.clone()] = fg_mask_inboxes
+
+        assigned_indexs = matching_matrix[:, fg_mask_inboxes].argmax(0)
+        assigned_labels = gt_classes[assigned_indexs]
+
+        assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[
+            fg_mask_inboxes
+        ]
+        return assigned_labels, assigned_ious, assigned_indexs
+    
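The dynamic-k rule above sets, for each ground truth, k = clamp(sum of its top candidate IoUs, min 1) and keeps the k lowest-cost anchors; an anchor claimed by several ground truths is reassigned to its lowest-cost one. A minimal standalone sketch of that rule on toy tensors (sizes and values are hypothetical, not taken from this commit):

import torch

torch.manual_seed(0)
num_gt, num_anchors, topk_candidate = 2, 6, 4      # toy sizes

cost = torch.rand(num_gt, num_anchors)             # pair-wise assignment cost
ious = torch.rand(num_gt, num_anchors)             # pair-wise IoU

matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)

# k_i = clamp(sum of the top IoUs of GT i, min=1): better-localized GTs get more anchors
topk_ious, _ = torch.topk(ious, min(topk_candidate, num_anchors), dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1).tolist()

for gt_idx in range(num_gt):
    _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx], largest=False)
    matching_matrix[gt_idx][pos_idx] = 1

# an anchor matched to several GTs keeps only its lowest-cost GT
multi = matching_matrix.sum(0) > 1
if multi.any():
    cost_argmin = cost[:, multi].argmin(dim=0)
    matching_matrix[:, multi] = 0
    matching_matrix[cost_argmin, multi] = 1

print(dynamic_ks)                   # e.g. [2, 1]
print(matching_matrix.sum(0) > 0)   # foreground mask over the 6 anchors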

+ 338 - 0
yolo/models/yolov10/modules.py

@@ -0,0 +1,338 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+# ---------------------------- 2D CNN ----------------------------
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+
+## Basic conv layer
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type='lrelu',     # activation
+                 norm_type='BN',       # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        if depthwise:
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            # depthwise conv
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+
+# ---------------------------- YOLOv7 Modules ----------------------------
+## ELAN-Block proposed by YOLOv7
+class ELANBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, squeeze_ratio=0.5, branch_depth :int=2, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANBlock, self).__init__()
+        inter_dim = int(in_dim * squeeze_ratio)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv3 = nn.Sequential(*[
+            Conv(inter_dim, inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(round(branch_depth))
+        ])
+        self.cv4 = nn.Sequential(*[
+            Conv(inter_dim, inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(round(branch_depth))
+        ])
+
+        self.out = Conv(inter_dim*4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+
+
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.cv3(x2)
+        x4 = self.cv4(x3)
+        out = self.out(torch.cat([x1, x2, x3, x4], dim=1))
+
+        return out
+
+## PaFPN's ELAN-Block proposed by YOLOv7
+class ELANBlockFPN(nn.Module):
+    def __init__(self, in_dim, out_dim, squeeze_ratio=0.5, branch_width :int=4, branch_depth :int=1, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANBlockFPN, self).__init__()
+        # Basic parameters
+        inter_dim = int(in_dim * squeeze_ratio)
+        inter_dim2 = int(inter_dim * squeeze_ratio) 
+        # Network structure
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv3 = nn.ModuleList()
+        for idx in range(round(branch_width)):
+            if idx == 0:
+                cvs = [Conv(inter_dim, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
+            else:
+                cvs = [Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
+            # deeper
+            if round(branch_depth) > 1:
+                for _ in range(1, round(branch_depth)):
+                    cvs.append(Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise))
+                self.cv3.append(nn.Sequential(*cvs))
+            else:
+                self.cv3.append(cvs[0])
+
+        self.out = Conv(inter_dim*2+inter_dim2*len(self.cv3), out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        inter_outs = [x1, x2]
+        for m in self.cv3:
+            y1 = inter_outs[-1]
+            y2 = m(y1)
+            inter_outs.append(y2)
+        out = self.out(torch.cat(inter_outs, dim=1))
+
+        return out
+
+## DownSample Block proposed by YOLOv7
+class DownSample(nn.Module):
+    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+        super().__init__()
+        inter_dim = out_dim // 2
+        self.mp = nn.MaxPool2d((2, 2), 2)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = nn.Sequential(
+            Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type),
+            Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+    def forward(self, x):
+        x1 = self.cv1(self.mp(x))
+        x2 = self.cv2(x)
+        out = torch.cat([x1, x2], dim=1)
+
+        return out
+
+
+# ---------------------------- RepConv Modules ----------------------------
+class RepConv(nn.Module):
+    """
+        Code adapted from https://github.com/WongKinYiu/yolov7/models/common.py
+    """
+    # Re-parameterized convolution (RepVGG)
+    # https://arxiv.org/abs/2101.03697
+
+    def __init__(self, c1, c2, k=3, s=1, p=1, g=1, act_type='silu', deploy=False):
+        super(RepConv, self).__init__()
+        # -------------- Basic parameters --------------
+        self.deploy = deploy
+        self.groups = g
+        self.in_channels = c1
+        self.out_channels = c2
+
+        # -------------- Network parameters --------------
+        if deploy:
+            self.rbr_reparam = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=True)
+
+        else:
+            self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)
+
+            self.rbr_dense = nn.Sequential(
+                nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False),
+                nn.BatchNorm2d(num_features=c2),
+            )
+
+            self.rbr_1x1 = nn.Sequential(
+                nn.Conv2d(c1, c2, kernel_size=1, stride=s, bias=False),
+                nn.BatchNorm2d(num_features=c2),
+            )
+        self.act = get_activation(act_type)
+
+
+    def forward(self, inputs):
+        if hasattr(self, "rbr_reparam"):
+            return self.act(self.rbr_reparam(inputs))
+
+        if self.rbr_identity is None:
+            id_out = 0
+        else:
+            id_out = self.rbr_identity(inputs)
+
+        return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
+    
+    def get_equivalent_kernel_bias(self):
+        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
+        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
+        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
+        return (
+            kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
+            bias3x3 + bias1x1 + biasid,
+        )
+
+    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
+        if kernel1x1 is None:
+            return 0
+        else:
+            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
+
+    def _fuse_bn_tensor(self, branch):
+        if branch is None:
+            return 0, 0
+        if isinstance(branch, nn.Sequential):
+            kernel = branch[0].weight
+            running_mean = branch[1].running_mean
+            running_var = branch[1].running_var
+            gamma = branch[1].weight
+            beta = branch[1].bias
+            eps = branch[1].eps
+        else:
+            assert isinstance(branch, nn.BatchNorm2d)
+            if not hasattr(self, "id_tensor"):
+                input_dim = self.in_channels // self.groups
+                kernel_value = np.zeros(
+                    (self.in_channels, input_dim, 3, 3), dtype=np.float32
+                )
+                for i in range(self.in_channels):
+                    kernel_value[i, i % input_dim, 1, 1] = 1
+                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
+            kernel = self.id_tensor
+            running_mean = branch.running_mean
+            running_var = branch.running_var
+            gamma = branch.weight
+            beta = branch.bias
+            eps = branch.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std
+
+    def repvgg_convert(self):
+        kernel, bias = self.get_equivalent_kernel_bias()
+        return (
+            kernel.detach().cpu().numpy(),
+            bias.detach().cpu().numpy(),
+        )
+
+    def fuse_conv_bn(self, conv, bn):
+
+        std = (bn.running_var + bn.eps).sqrt()
+        bias = bn.bias - bn.running_mean * bn.weight / std
+
+        t = (bn.weight / std).reshape(-1, 1, 1, 1)
+        weights = conv.weight * t
+
+        bn = nn.Identity()
+        conv = nn.Conv2d(in_channels = conv.in_channels,
+                              out_channels = conv.out_channels,
+                              kernel_size = conv.kernel_size,
+                              stride=conv.stride,
+                              padding = conv.padding,
+                              dilation = conv.dilation,
+                              groups = conv.groups,
+                              bias = True,
+                              padding_mode = conv.padding_mode)
+
+        conv.weight = torch.nn.Parameter(weights)
+        conv.bias = torch.nn.Parameter(bias)
+        return conv
+
+    def fuse_repvgg_block(self):    
+        if self.deploy:
+            return
+                
+        self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])
+        
+        self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
+        rbr_1x1_bias = self.rbr_1x1.bias
+        weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])
+        
+        # Fuse self.rbr_identity
+        if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
+            identity_conv_1x1 = nn.Conv2d(
+                    in_channels=self.in_channels,
+                    out_channels=self.out_channels,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                    groups=self.groups, 
+                    bias=False)
+            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
+            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
+
+            identity_conv_1x1.weight.data.fill_(0.0)
+            identity_conv_1x1.weight.data.fill_diagonal_(1.0)
+            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)
+
+            identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
+            bias_identity_expanded = identity_conv_1x1.bias
+            weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])            
+        else:
+            bias_identity_expanded = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) )
+            weight_identity_expanded = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) )            
+        
+        self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
+        self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)
+                
+        self.rbr_reparam = self.rbr_dense
+        self.deploy = True
+
+        if self.rbr_identity is not None:
+            del self.rbr_identity
+            self.rbr_identity = None
+
+        if self.rbr_1x1 is not None:
+            del self.rbr_1x1
+            self.rbr_1x1 = None
+
+        if self.rbr_dense is not None:
+            del self.rbr_dense
+            self.rbr_dense = None
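Both get_equivalent_kernel_bias and fuse_conv_bn above rest on folding a BatchNorm into the preceding convolution: w' = w * gamma / sqrt(var + eps), b' = beta - mean * gamma / sqrt(var + eps). A minimal standalone check of that fold in eval mode; the layer sizes are illustrative and the snippet is not part of the committed code:

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(16).eval()
bn.running_mean.uniform_(-1.0, 1.0)   # give BN non-trivial running statistics
bn.running_var.uniform_(0.5, 2.0)

with torch.no_grad():
    std = (bn.running_var + bn.eps).sqrt()
    t = (bn.weight / std).reshape(-1, 1, 1, 1)
    fused = nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(conv.weight * t)
    fused.bias.copy_(bn.bias - bn.running_mean * bn.weight / std)

    x = torch.randn(2, 8, 32, 32)
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))   # True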

+ 302 - 0
yolo/models/yolov10/yolov10.py

@@ -0,0 +1,302 @@
+import torch
+import torch.nn as nn
+
+from utils.misc import multiclass_nms
+
+from .yolov10_backbone import build_backbone
+from .yolov10_neck import build_neck
+from .yolov10_pafpn import build_fpn
+from .yolov10_head import build_head
+
+
+# YOLOv7
+class YOLOv7(nn.Module):
+    def __init__(self,
+                 cfg,
+                 device,
+                 num_classes=20,
+                 conf_thresh=0.01,
+                 topk=100,
+                 nms_thresh=0.5,
+                 trainable=False,
+                 deploy = False,
+                 no_multi_labels = False,
+                 nms_class_agnostic = False):
+        super(YOLOv7, self).__init__()
+        # ------------------- Basic parameters -------------------
+        self.cfg = cfg                                 # model config
+        self.device = device                           # cuda or cpu
+        self.num_classes = num_classes                 # number of classes
+        self.trainable = trainable                     # training flag
+        self.conf_thresh = conf_thresh                 # score threshold
+        self.nms_thresh = nms_thresh                   # NMS threshold
+        self.topk_candidates = topk                    # top-k candidates
+        self.stride = [8, 16, 32]                      # output strides of the network
+        self.num_levels = 3
+        self.deploy = deploy
+        self.no_multi_labels = no_multi_labels
+        self.nms_class_agnostic = nms_class_agnostic
+        # ------------------- Network Structure -------------------
+        ## Backbone
+        self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
+
+        ## Neck: SPP module
+        self.neck = build_neck(cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1]//2)
+        feats_dim[-1] = self.neck.out_dim
+
+        ## Neck: feature pyramid
+        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=round(256*cfg['channel_width']))
+        self.head_dim = self.fpn.out_dim
+
+        ## Detection heads
+        self.non_shared_heads = nn.ModuleList(
+            [build_head(cfg, head_dim, head_dim, num_classes) 
+            for head_dim in self.head_dim
+            ])
+
+        ## Prediction layers
+        self.obj_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 1, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.cls_preds = nn.ModuleList(
+                            [nn.Conv2d(head.cls_out_dim, self.num_classes, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.reg_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 4, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ])                 
+
+
+    # ---------------------- Basic Functions ----------------------
+    ## generate anchor points
+    def generate_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchor_xy += 0.5  # add center offset
+        anchor_xy *= self.stride[level]
+        anchors = anchor_xy.to(self.device)
+
+        return anchors
+        
+    ## post-process
+    def post_process(self, obj_preds, cls_preds, box_preds):
+        """
+        Input:
+            obj_preds: List[Tensor] -> [[M, 1], ...]
+            cls_preds: List[Tensor] -> [[M, C], ...]
+            box_preds: List[Tensor] -> [[M, 4], ...]
+        Output:
+            bboxes: np.array -> [N, 4]
+            scores: np.array -> [N,]
+            labels: np.array -> [N,]
+        """
+        assert len(cls_preds) == self.num_levels
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
+            if self.no_multi_labels:
+                # [M,]
+                scores, labels = torch.max(torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # topk candidates
+                predicted_prob, topk_idxs = scores.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                labels = labels[topk_idxs]
+                bboxes = box_pred_i[topk_idxs]
+
+            else:
+                # [M, C] -> [MC,]
+                scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # torch.sort is actually faster than .topk (at least on GPUs)
+                predicted_prob, topk_idxs = scores_i.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+                labels = topk_idxs % self.num_classes
+
+                bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
+
+        return bboxes, scores, labels
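In the multi-label branch above, the [M, C] score map is flattened to length M*C before the top-k, so each flat index encodes both the anchor (index // C) and the class (index % C). A tiny standalone check of that bookkeeping with made-up scores:

import torch

num_classes = 3
scores = torch.tensor([[0.1, 0.9, 0.2],    # anchor 0
                       [0.8, 0.3, 0.7]])   # anchor 1
flat = scores.flatten()                    # [M * C] = [6]
predicted_prob, topk_idxs = flat.sort(descending=True)
topk_scores, topk_idxs = predicted_prob[:2], topk_idxs[:2]

anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor')
labels = topk_idxs % num_classes
print(topk_scores)   # tensor([0.9000, 0.8000])
print(anchor_idxs)   # tensor([0, 1])
print(labels)        # tensor([1, 0])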
+    
+
+    # ---------------------- Main Process for Inference ----------------------
+    @torch.no_grad()
+    def inference_single_image(self, x):
+        # Backbone
+        pyramid_feats = self.backbone(x)
+
+        # Neck
+        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+        # Feature pyramid
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # Detection heads
+        all_obj_preds = []
+        all_cls_preds = []
+        all_box_preds = []
+        all_anchors = []
+        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+            cls_feat, reg_feat = head(feat)
+
+            # [1, C, H, W]
+            obj_pred = self.obj_preds[level](reg_feat)
+            cls_pred = self.cls_preds[level](cls_feat)
+            reg_pred = self.reg_preds[level](reg_feat)
+
+            # anchors: [M, 2]
+            fmp_size = cls_pred.shape[-2:]
+            anchors = self.generate_anchors(level, fmp_size)
+
+            # [1, C, H, W] -> [H, W, C] -> [M, C]
+            obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
+            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
+            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
+
+            # decode bbox
+            ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
+            wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
+            pred_x1y1 = ctr_pred - wh_pred * 0.5
+            pred_x2y2 = ctr_pred + wh_pred * 0.5
+            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+            all_obj_preds.append(obj_pred)
+            all_cls_preds.append(cls_pred)
+            all_box_preds.append(box_pred)
+            all_anchors.append(anchors)
+
+        if self.deploy:
+            obj_preds = torch.cat(all_obj_preds, dim=0)
+            cls_preds = torch.cat(all_cls_preds, dim=0)
+            box_preds = torch.cat(all_box_preds, dim=0)
+            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
+            bboxes = box_preds
+            # [n_anchors_all, 4 + C]
+            outputs = torch.cat([bboxes, scores], dim=-1)
+
+        else:
+            # post process
+            bboxes, scores, labels = self.post_process(
+                all_obj_preds, all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
+
+        return outputs
+
+    # ---------------------- Main Process for Training ----------------------
+    def forward(self, x):
+        if not self.trainable:
+            return self.inference_single_image(x)
+        else:
+            # Backbone
+            pyramid_feats = self.backbone(x)
+
+            # Neck
+            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+            # Feature pyramid
+            pyramid_feats = self.fpn(pyramid_feats)
+
+            # Detection heads
+            all_anchors = []
+            all_strides = []
+            all_obj_preds = []
+            all_cls_preds = []
+            all_box_preds = []
+            all_reg_preds = []
+            for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+                cls_feat, reg_feat = head(feat)
+
+                # [B, C, H, W]
+                obj_pred = self.obj_preds[level](reg_feat)
+                cls_pred = self.cls_preds[level](cls_feat)
+                reg_pred = self.reg_preds[level](reg_feat)
+
+                B, _, H, W = cls_pred.size()
+                fmp_size = [H, W]
+                # generate anchor points: [M, 2]
+                anchors = self.generate_anchors(level, fmp_size)
+                
+                # stride tensor: [M, 1]
+                stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
+
+                # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+                obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
+                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+
+                # decode bbox
+                ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
+                wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
+                pred_x1y1 = ctr_pred - wh_pred * 0.5
+                pred_x2y2 = ctr_pred + wh_pred * 0.5
+                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+                all_obj_preds.append(obj_pred)
+                all_cls_preds.append(cls_pred)
+                all_box_preds.append(box_pred)
+                all_reg_preds.append(reg_pred)
+                all_anchors.append(anchors)
+                all_strides.append(stride_tensor)
+            
+            # output dict
+            outputs = {"pred_obj": all_obj_preds,        # List(Tensor) [B, M, 1]
+                       "pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
+                       "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
+                       "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4]
+                       "anchors": all_anchors,           # List(Tensor) [M, 2]
+                       "strides": self.stride,           # List(Int) [8, 16, 32]
+                       "stride_tensors": all_strides     # List(Tensor) [M, 1]
+                       }
+
+            return outputs 
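Both inference_single_image and forward above decode boxes the same way: anchors are grid-cell centers scaled by the stride, the predicted center is an offset in units of stride, and width/height go through exp. A standalone sketch for a 2x2 feature map at stride 8; the regression values are made up and indexing='ij' assumes a recent PyTorch:

import torch

stride = 8
fmp_h, fmp_w = 2, 2
ys, xs = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)], indexing='ij')
anchors = (torch.stack([xs, ys], dim=-1).float().view(-1, 2) + 0.5) * stride
print(anchors)   # [[4, 4], [12, 4], [4, 12], [12, 12]]

# one raw regression vector (dx, dy, log_w, log_h) per anchor
reg_pred = torch.tensor([[0.5, -0.25, 0.0, 0.6931]]).repeat(4, 1)
ctr = reg_pred[:, :2] * stride + anchors        # center: offset in units of stride
wh = torch.exp(reg_pred[:, 2:]) * stride        # width/height via exp, in pixels
boxes = torch.cat([ctr - wh * 0.5, ctr + wh * 0.5], dim=-1)   # x1, y1, x2, y2
print(boxes[0])  # approximately tensor([ 4., -6., 12., 10.])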

Some files are not shown because too many files changed in this diff.