
add Lightweight Object Detector (LODet)

yjh0410, 2 years ago
parent
commit
d27a075efa

+ 1 - 1
README.md

@@ -153,7 +153,7 @@ python train.py --cuda -d coco --root path/to/COCO -m yolov1 -bs 16 --max_epoch
 
 - We use `AdamW` optimizer with `per_image_lr=0.001 / 64` and a `linear` learning rate decay scheduler to train all models for 300 epochs.
 - We use `YOLOv5-style Mosaic augmentation` and `YOLOX-style Mixup augmentation` without rotation.
-- Due to my limited computing resources, I can not to train `YOLOvx-X` with the setting of `batch size=128`.
+- Due to my limited computing resources, I can not train `YOLOvx-X` with the setting of `batch size=128`.
 
 #### Redesigned RT-DETR:
 

+ 5 - 0
config/__init__.py

@@ -106,6 +106,8 @@ from .model_config.yolov5_config import yolov5_cfg
 from .model_config.yolov7_config import yolov7_cfg
 from .model_config.yolovx_config import yolovx_cfg
 from .model_config.yolox_config import yolox_cfg
+## LODet
+from .model_config.lodet_config import lodet_cfg
 ## Real-Time DETR
 from .model_config.rtdetr_config import rtdetr_cfg
 
@@ -137,6 +139,9 @@ def build_model_config(args):
     # YOLOvx
     elif args.model in ['yolovx_n', 'yolovx_s', 'yolovx_m', 'yolovx_l', 'yolovx_x']:
         cfg = yolovx_cfg[args.model]
+    # LODet
+    elif args.model == 'lodet':
+        cfg = lodet_cfg
     # RT-DETR
     elif args.model in ['rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x']:
         cfg = rtdetr_cfg[args.model]

+ 53 - 0
config/model_config/lodet_config.py

@@ -0,0 +1,53 @@
+# Light-weight Object Detector Config
+
+
+lodet_cfg = {
+    # ---------------- Model config ----------------
+    ## Backbone
+    'backbone': 'smnet',
+    'pretrained': True,
+    'bk_act': 'silu',
+    'bk_norm': 'BN',
+    'bk_dpw': True,
+    'stride': [8, 16, 32],  # P3, P4, P5
+    'max_stride': 32,
+    ## Neck: SPP
+    'neck': 'sppf',
+    'neck_expand_ratio': 0.5,
+    'pooling_size': 5,
+    'neck_act': 'silu',
+    'neck_norm': 'BN',
+    'neck_depthwise': True,
+    ## Neck: PaFPN
+    'fpn': 'lodet_pafpn',
+    'fpn_reduce_layer': 'conv',
+    'fpn_downsample_layer': 'maxpool',
+    'fpn_core_block': 'smblock',
+    'fpn_expand_ratio': 0.5,
+    'fpn_act': 'silu',
+    'fpn_norm': 'BN',
+    'fpn_depthwise': True,
+    ## Head
+    'head': 'decoupled_head',
+    'head_act': 'silu',
+    'head_norm': 'BN',
+    'num_cls_head': 2,
+    'num_reg_head': 2,
+    'head_depthwise': True,
+    'reg_max': 16,
+    # ---------------- Data process config ----------------
+    ## Input
+    'multi_scale': [0.5, 1.25],   # 320 -> 800
+    'trans_type': 'yolovx_pico',
+    # ---------------- Assignment config ----------------
+    ## Matcher
+    'matcher': {'center_sampling_radius': 2.5,
+                'topk_candidate': 10},
+    # ---------------- Loss config ----------------
+    ## Loss weight
+    'loss_cls_weight': 1.0,
+    'loss_box_weight': 5.0,
+    'loss_dfl_weight': 1.0,
+    # ---------------- Train config ----------------
+    'trainer_type': 'rtmdet',
+}
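
For reference, a minimal sketch of how this new entry is dispatched by `build_model_config` (run from the repository root; the `SimpleNamespace` below is only a stand-in for the real argparse namespace):

from types import SimpleNamespace
from config import build_model_config

args = SimpleNamespace(model='lodet')
cfg = build_model_config(args)
print(cfg['backbone'], cfg['fpn'], cfg['head'])  # smnet lodet_pafpn decoupled_head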

+ 5 - 5
config/model_config/yolox_config.py

@@ -43,7 +43,7 @@ yolox_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtmdet',
+        'trainer_type': 'yolox',
     },
 
     'yolox_s':{
@@ -87,7 +87,7 @@ yolox_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtmdet',
+        'trainer_type': 'yolox',
     },
 
     'yolox_m':{
@@ -131,7 +131,7 @@ yolox_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtmdet',
+        'trainer_type': 'yolox',
     },
 
     'yolox_l':{
@@ -175,7 +175,7 @@ yolox_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtmdet',
+        'trainer_type': 'yolox',
     },
 
     'yolox_x':{
@@ -219,7 +219,7 @@ yolox_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtmdet',
+        'trainer_type': 'yolox',
     },
 
 }

+ 5 - 0
models/detectors/__init__.py

@@ -10,6 +10,7 @@ from .yolov5.build import build_yolov5
 from .yolov7.build import build_yolov7
 from .yolovx.build import build_yolovx
 from .yolox.build import build_yolox
+from .lodet.build import build_lodet
 from .rtdetr.build import build_rtdetr
 
 
@@ -52,6 +53,10 @@ def build_model(args,
     elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
         model, criterion = build_yolox(
             args, model_cfg, device, num_classes, trainable, deploy)
+    # LODet
+    elif args.model == 'lodet':
+        model, criterion = build_lodet(
+            args, model_cfg, device, num_classes, trainable, deploy)
     # RT-DETR
     elif args.model in ['rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x']:
         model, criterion = build_rtdetr(

+ 39 - 0
models/detectors/lodet/build.py

@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+import torch.nn as nn
+
+from .loss import build_criterion
+from .lodet import LODet
+
+
+# build object detector
+def build_lodet(args, cfg, device, num_classes=80, trainable=False, deploy=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+        
+    # -------------- Build LODet --------------
+    model = LODet(
+        cfg=cfg,
+        device=device, 
+        num_classes=num_classes,
+        trainable=trainable,
+        conf_thresh=args.conf_thresh,
+        nms_thresh=args.nms_thresh,
+        topk=args.topk,
+        deploy=deploy
+        )
+
+    # -------------- Initialize LODet --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+            
+    # -------------- Build criterion --------------
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, device, num_classes)
+    return model, criterion
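
A hedged usage sketch for `build_lodet` (the `args` fields below are just the ones it reads; the threshold values are illustrative, not the repo's defaults):

import torch
from types import SimpleNamespace
from config import build_model_config
from models.detectors.lodet.build import build_lodet

args = SimpleNamespace(model='lodet', conf_thresh=0.001, nms_thresh=0.7, topk=1000)
cfg = build_model_config(args)
model, criterion = build_lodet(args, cfg, torch.device('cpu'), num_classes=80, trainable=False)
print(type(model).__name__, criterion)  # LODet None -- the criterion is only built when trainable=True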

+ 176 - 0
models/detectors/lodet/lodet.py

@@ -0,0 +1,176 @@
+# --------------- Torch components ---------------
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from .lodet_backbone import build_backbone
+from .lodet_neck import build_neck
+from .lodet_pafpn import build_fpn
+from .lodet_head import build_det_head
+from .lodet_pred import build_pred_layer
+
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
+
+
+# Lightweight Object Detector
+class LODet(nn.Module):
+    def __init__(self, 
+                 cfg,
+                 device, 
+                 num_classes = 20, 
+                 conf_thresh = 0.05,
+                 nms_thresh = 0.6,
+                 trainable = False, 
+                 topk = 1000,
+                 deploy = False):
+        super(LODet, self).__init__()
+        # ---------------------- Basic Parameters ----------------------
+        self.cfg = cfg
+        self.device = device
+        self.stride = cfg['stride']
+        self.reg_max = cfg['reg_max']
+        self.num_classes = num_classes
+        self.trainable = trainable
+        self.conf_thresh = conf_thresh
+        self.nms_thresh = nms_thresh
+        self.topk = topk
+        self.deploy = deploy
+        self.head_dim = 64
+        
+        # ---------------------- Network Parameters ----------------------
+        ## ----------- Backbone -----------
+        self.backbone, feats_dim = build_backbone(cfg, trainable and cfg['pretrained'])
+
+        ## ----------- Neck: SPP -----------
+        self.neck = build_neck(cfg, feats_dim[-1], feats_dim[-1])
+        feats_dim[-1] = self.neck.out_dim
+        
+        ## ----------- Neck: FPN -----------
+        self.fpn = build_fpn(cfg, feats_dim, self.head_dim)
+        self.fpn_dims = self.fpn.out_dim
+
+        ## ----------- Heads -----------
+        self.det_heads = build_det_head(
+            cfg, self.fpn_dims, self.head_dim, num_classes, num_levels=len(self.stride))
+
+        ## ----------- Preds -----------
+        self.pred_layers = build_pred_layer(
+            self.det_heads.cls_head_dim, self.det_heads.reg_head_dim,
+            self.stride, num_classes, num_coords=4, num_levels=len(self.stride),
+            reg_max=self.reg_max)
+
+
+    ## post-process
+    def post_process(self, cls_preds, box_preds):
+        """
+        Input:
+            cls_preds: List(Tensor) [[H x W, C], ...]
+            box_preds: List(Tensor) [[H x W, 4], ...]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            
+            # (H x W x C,)
+            scores_i = cls_pred_i.sigmoid().flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk, box_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
+
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            labels = topk_idxs % self.num_classes
+
+            bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
+
+        return bboxes, scores, labels
+
+
+    # ---------------------- Main Process for Inference ----------------------
+    @torch.no_grad()
+    def inference_single_image(self, x):
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(x)
+
+        # ---------------- Neck: SPP ----------------
+        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+        # ---------------- Neck: PaFPN ----------------
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.det_heads(pyramid_feats)
+
+        # ---------------- Preds ----------------
+        outputs = self.pred_layers(cls_feats, reg_feats)
+
+        all_cls_preds = outputs['pred_cls']
+        all_box_preds = outputs['pred_box']
+
+        if self.deploy:
+            cls_preds = torch.cat(all_cls_preds, dim=1)[0]
+            box_preds = torch.cat(all_box_preds, dim=1)[0]
+            scores = cls_preds.sigmoid()
+            bboxes = box_preds
+            # [n_anchors_all, 4 + C]
+            outputs = torch.cat([bboxes, scores], dim=-1)
+
+            return outputs
+        else:
+            # post process
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+        
+            return bboxes, scores, labels
+
+
+    def forward(self, x):
+        if not self.trainable:
+            return self.inference_single_image(x)
+        else:
+            # ---------------- Backbone ----------------
+            pyramid_feats = self.backbone(x)
+
+            # ---------------- Neck: SPP ----------------
+            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+            # ---------------- Neck: PaFPN ----------------
+            pyramid_feats = self.fpn(pyramid_feats)
+
+            # ---------------- Heads ----------------
+            cls_feats, reg_feats = self.det_heads(pyramid_feats)
+
+            # ---------------- Preds ----------------
+            outputs = self.pred_layers(cls_feats, reg_feats)
+            
+            return outputs 
+        
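
Continuing the build sketch above, a minimal inference pass (CPU, random weights; `utils.misc.multiclass_nms` is assumed to be available as imported at the top of lodet.py):

import torch

model.eval()  # trainable=False, so forward() takes the inference_single_image path
x = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    bboxes, scores, labels = model(x)
print(bboxes.shape, scores.shape, labels.shape)  # (N, 4), (N,), (N,) numpy arrays after NMS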

+ 127 - 0
models/detectors/lodet/lodet_backbone.py

@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+try:
+    from .lodet_basic import Conv, SMBlock
+except ImportError:
+    from lodet_basic import Conv, SMBlock
+
+
+
+model_urls = {
+    'smnet': None,
+}
+
+
+# ---------------------------- Backbones ----------------------------
+class ScaleModulationNet(nn.Module):
+    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
+        super(ScaleModulationNet, self).__init__()
+        self.feat_dims = [128, 256, 256]
+        
+        # P1/2
+        self.layer_1 = Conv(3, 32, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type)
+
+        # P2/4
+        self.layer_2 = nn.Sequential(   
+            nn.MaxPool2d((2, 2), stride=2),             
+            SMBlock(32, 64, 0.5, act_type, norm_type, depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            nn.MaxPool2d((2, 2), stride=2),             
+            SMBlock(64, 128, 0.5, act_type, norm_type, depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            nn.MaxPool2d((2, 2), stride=2),             
+            SMBlock(128, 256, 0.5, act_type, norm_type, depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            nn.MaxPool2d((2, 2), stride=2),             
+            SMBlock(256, 256, 0.25, act_type, norm_type, depthwise)
+        )
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+## load pretrained weight
+def load_weight(model, model_name):
+    # load weight
+    print('Loading pretrained weight ...')
+    url = model_urls[model_name]
+    if url is not None:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url=url, map_location="cpu", check_hash=True)
+        # checkpoint state dict
+        checkpoint_state_dict = checkpoint.pop("model")
+        # model state dict
+        model_state_dict = model.state_dict()
+        # check
+        for k in list(checkpoint_state_dict.keys()):
+            if k in model_state_dict:
+                shape_model = tuple(model_state_dict[k].shape)
+                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                if shape_model != shape_checkpoint:
+                    checkpoint_state_dict.pop(k)
+            else:
+                checkpoint_state_dict.pop(k)
+                print(k)
+
+        model.load_state_dict(checkpoint_state_dict)
+    else:
+        print('No pretrained for {}'.format(model_name))
+
+    return model
+
+
+## build SMnet
+def build_backbone(cfg, pretrained=False): 
+    # model
+    backbone = ScaleModulationNet(
+        act_type=cfg['bk_act'],
+        norm_type=cfg['bk_norm'],
+        depthwise=cfg['bk_dpw']
+        )
+    # check whether to load imagenet pretrained weight
+    if pretrained:
+        backbone = load_weight(backbone, model_name='smnet')
+    feat_dims = backbone.feat_dims
+
+    return backbone, feat_dims
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': True,
+    }
+    model, feats = build_backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 165 - 0
models/detectors/lodet/lodet_basic.py

@@ -0,0 +1,165 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+# ---------------------------- 2D CNN ----------------------------
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+
+
+# Basic conv layer
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type='lrelu',     # activation
+                 norm_type='BN',       # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        p = p if d == 1 else d
+        if depthwise:
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            # depthwise conv
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+
+# ---------------------------- Core Modules ----------------------------
+## Scale Modulation Block
+class SMBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5, act_type='silu', norm_type='BN', depthwise=False):
+        super(SMBlock, self).__init__()
+        # -------------- Basic parameters --------------
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.expand_ratio = expand_ratio
+        self.inter_dim = round(in_dim * expand_ratio)
+        # -------------- Network parameters --------------
+        ## Input proj
+        self.cv1 = Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(in_dim, self.inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        ## Scale Modulation
+        self.sm1 = Conv(self.inter_dim, self.inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.sm2 = Conv(self.inter_dim, self.inter_dim, k=5, p=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.sm3 = Conv(self.inter_dim, self.inter_dim, k=7, p=3, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        ## Output proj
+        self.cv3 = Conv(self.inter_dim*4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+
+    def channel_shuffle(self, x, groups):
+        # type: (torch.Tensor, int) -> torch.Tensor
+        batchsize, num_channels, height, width = x.data.size()
+        per_group_dim = num_channels // groups
+
+        # reshape
+        x = x.view(batchsize, groups, per_group_dim, height, width)
+
+        x = torch.transpose(x, 1, 2).contiguous()
+
+        # flatten
+        x = x.view(batchsize, -1, height, width)
+
+        return x
+    
+
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.sm1(self.cv2(x))
+        x3 = self.sm2(x2)
+        x4 = self.sm3(x3)
+        out = torch.cat([x1, x2, x3, x4], dim=1)
+        out = self.channel_shuffle(out, groups=4)
+
+        out = self.cv3(out)
+
+        return out
+
+
+# ---------------------------- FPN Modules ----------------------------
+## build fpn's core block
+def build_fpn_block(cfg, in_dim, out_dim):
+    if cfg['fpn_core_block'] == 'smblock':
+        layer = SMBlock(in_dim=in_dim,
+                        out_dim=out_dim,
+                        expand_ratio=cfg['fpn_expand_ratio'],
+                        act_type=cfg['fpn_act'],
+                        norm_type=cfg['fpn_norm'],
+                        depthwise=cfg['fpn_depthwise']
+                        )
+        
+    return layer
+
+## build fpn's reduce layer
+def build_reduce_layer(cfg, in_dim, out_dim):
+    if cfg['fpn_reduce_layer'] == 'conv':
+        layer = Conv(in_dim, out_dim, k=1, act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+        
+    return layer
+
+## build fpn's downsample layer
+def build_downsample_layer(cfg, in_dim, out_dim):
+    if cfg['fpn_downsample_layer'] == 'conv':
+        layer = Conv(in_dim, out_dim, k=3, s=2, p=1, act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+    elif cfg['fpn_downsample_layer'] == 'maxpool':
+        assert in_dim == out_dim
+        layer = nn.MaxPool2d((2, 2), stride=2)
+        
+    return layer
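
A small shape check for `SMBlock` (run from the repository root; the import path is assumed):

import torch
from models.detectors.lodet.lodet_basic import SMBlock

# A 1x1 branch and the outputs of three stacked 3/5/7 depthwise convs are concatenated
# (4 * inter_dim channels), channel-shuffled with groups=4, then projected to out_dim.
m = SMBlock(in_dim=64, out_dim=128, expand_ratio=0.5, depthwise=True)
y = m(torch.randn(2, 64, 40, 40))
print(y.shape)  # torch.Size([2, 128, 40, 40]) -- spatial size is preserved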

+ 117 - 0
models/detectors/lodet/lodet_head.py

@@ -0,0 +1,117 @@
+import torch
+import torch.nn as nn
+
+from .lodet_basic import Conv
+
+
+# Single-level Head
+class SingleLevelHead(nn.Module):
+    def __init__(self, in_dim, out_dim, num_classes, num_cls_head, num_reg_head, act_type, norm_type, depthwise):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dim = in_dim
+        self.num_classes = num_classes
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        self.act_type = act_type
+        self.norm_type = norm_type
+        self.depthwise = depthwise
+        
+        # --------- Network Parameters ----------
+        ## cls head
+        cls_feats = []
+        self.cls_head_dim = max(out_dim, num_classes)
+        for i in range(num_cls_head):
+            if i == 0:
+                cls_feats.append(
+                    Conv(in_dim, self.cls_head_dim, k=3, p=1, s=1, 
+                         act_type=act_type,
+                         norm_type=norm_type,
+                         depthwise=depthwise)
+                        )
+            else:
+                cls_feats.append(
+                    Conv(self.cls_head_dim, self.cls_head_dim, k=3, p=1, s=1, 
+                        act_type=act_type,
+                        norm_type=norm_type,
+                        depthwise=depthwise)
+                        )      
+        ## reg head
+        reg_feats = []
+        self.reg_head_dim = out_dim
+        for i in range(num_reg_head):
+            if i == 0:
+                reg_feats.append(
+                    Conv(in_dim, self.reg_head_dim, k=3, p=1, s=1, 
+                         act_type=act_type,
+                         norm_type=norm_type,
+                         depthwise=depthwise)
+                        )
+            else:
+                reg_feats.append(
+                    Conv(self.reg_head_dim, self.reg_head_dim, k=3, p=1, s=1, 
+                         act_type=act_type,
+                         norm_type=norm_type,
+                         depthwise=depthwise)
+                        )
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+
+    def forward(self, x):
+        """
+            in_feats: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+
+# Multi-level Head
+class MultiLevelHead(nn.Module):
+    def __init__(self, cfg, in_dims, out_dim, num_classes=80, num_levels=3):
+        super().__init__()
+        ## ----------- Network Parameters -----------
+        self.multi_level_heads = nn.ModuleList(
+            [SingleLevelHead(
+                in_dims[level],
+                out_dim,
+                num_classes,
+                cfg['num_cls_head'],
+                cfg['num_reg_head'],
+                cfg['head_act'],
+                cfg['head_norm'],
+                cfg['head_depthwise'])
+                for level in range(num_levels)
+            ])
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.num_classes = num_classes
+
+        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
+        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
+
+
+    def forward(self, feats):
+        """
+            feats: List[(Tensor)] [[B, C, H, W], ...]
+        """
+        cls_feats = []
+        reg_feats = []
+        for feat, head in zip(feats, self.multi_level_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+
+        return cls_feats, reg_feats
+    
+
+# build detection head
+def build_det_head(cfg, in_dim, out_dim, num_classes=80, num_levels=3):
+    if cfg['head'] == 'decoupled_head':
+        head = MultiLevelHead(cfg, in_dim, out_dim, num_classes, num_levels) 
+
+    return head
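
A hedged construction sketch for the decoupled head, using only the config keys it reads (the channel width 64 matches head_dim in lodet.py):

import torch
from models.detectors.lodet.lodet_head import build_det_head

cfg = {'head': 'decoupled_head', 'head_act': 'silu', 'head_norm': 'BN',
       'num_cls_head': 2, 'num_reg_head': 2, 'head_depthwise': True}
head = build_det_head(cfg, in_dim=[64, 64, 64], out_dim=64, num_classes=80, num_levels=3)
feats = [torch.randn(1, 64, s, s) for s in (80, 40, 20)]   # P3/P4/P5 for a 640x640 input
cls_feats, reg_feats = head(feats)
print(cls_feats[0].shape, reg_feats[0].shape)  # [1, 80, 80, 80] and [1, 64, 80, 80]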

+ 71 - 0
models/detectors/lodet/lodet_neck.py

@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+
+from .lodet_basic import Conv
+
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    """
+        This code referenced to https://github.com/ultralytics/yolov5
+    """
+    def __init__(self, cfg, in_dim, out_dim, expand_ratio=0.5):
+        super().__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.m = nn.MaxPool2d(kernel_size=cfg['pooling_size'], stride=1, padding=cfg['pooling_size'] // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+
+
+# SPPF block with CSP module
+class SPPFBlockCSP(nn.Module):
+    """
+        CSP Spatial Pyramid Pooling Block
+    """
+    def __init__(self, cfg, in_dim, out_dim, expand_ratio):
+        super(SPPFBlockCSP, self).__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.m = nn.Sequential(
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=cfg['neck_act'], norm_type=cfg['neck_norm'], 
+                 depthwise=cfg['neck_depthwise']),
+            SPPF(cfg, inter_dim, inter_dim, expand_ratio=1.0),
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=cfg['neck_act'], norm_type=cfg['neck_norm'], 
+                 depthwise=cfg['neck_depthwise'])
+        )
+        self.cv3 = Conv(inter_dim * 2, self.out_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+
+        
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.m(x2)
+        y = self.cv3(torch.cat([x1, x3], dim=1))
+
+        return y
+
+
+def build_neck(cfg, in_dim, out_dim):
+    model = cfg['neck']
+    print('==============================')
+    print('Neck: {}'.format(model))
+    # build neck
+    if model == 'sppf':
+        neck = SPPF(cfg, in_dim, out_dim, cfg['neck_expand_ratio'])
+    elif model == 'csp_sppf':
+        neck = SPPFBlockCSP(cfg, in_dim, out_dim, cfg['neck_expand_ratio'])
+
+    return neck
+        
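
A quick shape check for the SPPF neck (the stride-1 max-pools keep the resolution, so only channels change):

import torch
from models.detectors.lodet.lodet_neck import build_neck

cfg = {'neck': 'sppf', 'neck_expand_ratio': 0.5, 'pooling_size': 5,
       'neck_act': 'silu', 'neck_norm': 'BN', 'neck_depthwise': True}
neck = build_neck(cfg, in_dim=256, out_dim=256)
y = neck(torch.randn(1, 256, 20, 20))   # the C5 map of a 640x640 input
print(y.shape)  # torch.Size([1, 256, 20, 20])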

+ 92 - 0
models/detectors/lodet/lodet_pafpn.py

@@ -0,0 +1,92 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .lodet_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
+
+
+# YOLO-Style PaFPN
+class LodetPaFPN(nn.Module):
+    def __init__(self, cfg, in_dims=[128, 256, 256], out_dim=None):
+        super(LodetPaFPN, self).__init__()
+        # --------------------------- Basic Parameters ---------------------------
+        self.in_dims = in_dims
+        c3, c4, c5 = in_dims
+        
+        # --------------------------- Top-down FPN---------------------------
+        ## P5 -> P4
+        self.reduce_layer_1 = build_reduce_layer(cfg, c5, 128)
+        self.reduce_layer_2 = build_reduce_layer(cfg, c4, 128)
+        self.top_down_layer_1 = build_fpn_block(cfg, 128 + 128, 128)
+
+        ## P4 -> P3
+        self.reduce_layer_3 = build_reduce_layer(cfg, 128, 64)
+        self.reduce_layer_4 = build_reduce_layer(cfg, c3, 64)
+        self.top_down_layer_2 = build_fpn_block(cfg, 64 + 64, 64)
+
+        # --------------------------- Bottom-up FPN ---------------------------
+        ## P3 -> P4
+        self.downsample_layer_1 = build_downsample_layer(cfg, 64, 64)
+        self.bottom_up_layer_1 = build_fpn_block(cfg, 64 + 64, 128)
+
+        ## P4 -> P5
+        self.downsample_layer_2 = build_downsample_layer(cfg, 128, 128)
+        self.bottom_up_layer_2 = build_fpn_block(cfg, 128 + 128, 256)
+                
+        # --------------------------- Output proj ---------------------------
+        if out_dim is not None:
+            self.out_layers = nn.ModuleList([
+                Conv(in_dim, out_dim, k=1,
+                     act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+                     for in_dim in [64, 128, 256]
+                     ])
+            self.out_dim = [out_dim] * 3
+        else:
+            self.out_layers = None
+            self.out_dim = self.in_dims
+
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # Top down
+        ## P5 -> P4
+        c6 = self.reduce_layer_1(c5)
+        c7 = self.reduce_layer_2(c4)
+        c8 = torch.cat([F.interpolate(c6, scale_factor=2.0), c7], dim=1)
+        c9 = self.top_down_layer_1(c8)
+        ## P4 -> P3
+        c10 = self.reduce_layer_3(c9)
+        c11 = self.reduce_layer_4(c3)
+        c12 = torch.cat([F.interpolate(c10, scale_factor=2.0), c11], dim=1)
+        c13 = self.top_down_layer_2(c12)
+
+        # Bottom up
+        # p3 -> P4
+        c14 = self.downsample_layer_1(c13)
+        c15 = torch.cat([c14, c10], dim=1)
+        c16 = self.bottom_up_layer_1(c15)
+        # P4 -> P5
+        c17 = self.downsample_layer_2(c16)
+        c18 = torch.cat([c17, c6], dim=1)
+        c19 = self.bottom_up_layer_2(c18)
+
+        out_feats = [c13, c16, c19] # [P3, P4, P5]
+        
+        # output proj layers
+        if self.out_layers is not None:
+            out_feats_proj = []
+            for feat, layer in zip(out_feats, self.out_layers):
+                out_feats_proj.append(layer(feat))
+            return out_feats_proj
+
+        return out_feats
+
+
+def build_fpn(cfg, in_dims, out_dim=None):
+    model = cfg['fpn']
+    # build pafpn
+    if model == 'lodet_pafpn':
+        fpn_net = LodetPaFPN(cfg, in_dims, out_dim)
+
+    return fpn_net
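
A shape sketch for the PaFPN with the backbone dims above (in_dims=[128, 256, 256] from ScaleModulationNet, out_dim=64 as set by head_dim in lodet.py):

import torch
from models.detectors.lodet.lodet_pafpn import build_fpn

cfg = {'fpn': 'lodet_pafpn', 'fpn_reduce_layer': 'conv', 'fpn_downsample_layer': 'maxpool',
       'fpn_core_block': 'smblock', 'fpn_expand_ratio': 0.5,
       'fpn_act': 'silu', 'fpn_norm': 'BN', 'fpn_depthwise': True}
fpn = build_fpn(cfg, in_dims=[128, 256, 256], out_dim=64)
feats = [torch.randn(1, c, s, s) for c, s in zip([128, 256, 256], [80, 40, 20])]
p3, p4, p5 = fpn(feats)
print(p3.shape, p4.shape, p5.shape)  # 64 channels each, at 80x80 / 40x40 / 20x20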

+ 154 - 0
models/detectors/lodet/lodet_pred.py

@@ -0,0 +1,154 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Single-level pred layer
+class SingleLevelPredLayer(nn.Module):
+    def __init__(self, cls_dim, reg_dim, num_classes, num_coords=4):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.num_classes = num_classes
+        self.num_coords = num_coords
+
+        # --------- Network Parameters ----------
+        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
+        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)                
+
+        self.init_bias()
+        
+
+    def init_bias(self):
+        # Init bias
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # cls pred
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = self.reg_pred.weight
+        w.data.fill_(0.)
+        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    def forward(self, cls_feat, reg_feat):
+        """
+            in_feats: (Tensor) [B, C, H, W]
+        """
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+
+        return cls_pred, reg_pred
+    
+
+# Multi-level pred layer
+class MultiLevelPredLayer(nn.Module):
+    def __init__(self, cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3, reg_max=16):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.strides = strides
+        self.num_classes = num_classes
+        self.num_coords = num_coords
+        self.num_levels = num_levels
+        self.reg_max = reg_max
+
+        # ----------- Network Parameters -----------
+        ## pred layers
+        self.multi_level_preds = nn.ModuleList(
+            [SingleLevelPredLayer(
+                cls_dim,
+                reg_dim,
+                num_classes,
+                num_coords * self.reg_max)
+                for _ in range(num_levels)
+            ])
+        ## proj conv
+        self.proj = nn.Parameter(torch.linspace(0, reg_max, reg_max), requires_grad=False)
+        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
+        self.proj_conv.weight = nn.Parameter(self.proj.view([1, reg_max, 1, 1]).clone().detach(), requires_grad=False)
+
+
+    def generate_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.strides[level]
+
+        return anchors
+        
+
+    def forward(self, cls_feats, reg_feats):
+        all_anchors = []
+        all_strides = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        for level in range(self.num_levels):
+            # pred
+            cls_pred, reg_pred = self.multi_level_preds[level](
+                cls_feats[level], reg_feats[level])
+
+            # generate anchor boxes: [M, 4]
+            B, _, H, W = cls_pred.size()
+            fmp_size = [H, W]
+            anchors = self.generate_anchors(level, fmp_size)
+            anchors = anchors.to(cls_pred.device)
+            # stride tensor: [M, 1]
+            stride_tensor = torch.ones_like(anchors[..., :1]) * self.strides[level]
+            
+            # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+
+            # ----------------------- Decode bbox -----------------------
+            B, M = reg_pred.shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+            reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## tlbr -> xyxy
+            x1y1_pred = anchors[None] - reg_pred_[..., :2] * self.strides[level]
+            x2y2_pred = anchors[None] + reg_pred_[..., 2:] * self.strides[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+            all_cls_preds.append(cls_pred)
+            all_reg_preds.append(reg_pred)
+            all_box_preds.append(box_pred)
+            all_anchors.append(anchors)
+            all_strides.append(stride_tensor)
+        
+        # output dict
+        outputs = {"pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
+                   "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
+                   "anchors": all_anchors,           # List(Tensor) [M, 2]
+                   "strides": self.strides,          # List(Int) = [8, 16, 32]
+                   "stride_tensor": all_strides      # List(Tensor) [M, 1]
+                   }
+
+        return outputs
+    
+
+# build prediction layer
+def build_pred_layer(cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3, reg_max=16):
+    pred_layers = MultiLevelPredLayer(cls_dim, reg_dim, strides, num_classes, num_coords, num_levels, reg_max)
+
+    return pred_layers
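
A self-contained sketch of the anchor grid that generate_anchors produces for one level (cell centers in input-image coordinates):

import torch

stride, fmp_h, fmp_w = 8, 3, 3                       # a tiny 3x3 feature map at stride 8
ys, xs = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
anchors = torch.stack([xs, ys], dim=-1).float().view(-1, 2)
anchors = (anchors + 0.5) * stride                   # shift to cell centers, scale by the stride
print(anchors[:3])                                   # tensor([[ 4.,  4.], [12.,  4.], [20.,  4.]])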

+ 194 - 0
models/detectors/lodet/loss.py

@@ -0,0 +1,194 @@
+import torch
+import torch.nn.functional as F
+
+from utils.box_ops import  bbox2dist, get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import AlignedSimOTA
+
+
+class Criterion(object):
+    def __init__(self, 
+                 cfg, 
+                 device, 
+                 num_classes=80):
+        self.cfg = cfg
+        self.device = device
+        self.num_classes = num_classes
+        # loss weight
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_box_weight = cfg['loss_box_weight']
+        self.loss_dfl_weight = cfg['loss_dfl_weight']
+        # matcher
+        matcher_config = cfg['matcher']
+        self.matcher = AlignedSimOTA(
+            num_classes=num_classes,
+            center_sampling_radius=matcher_config['center_sampling_radius'],
+            topk_candidate=matcher_config['topk_candidate']
+            )
+
+
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+
+        return loss_cls
+
+
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box, gt_box, "xyxy", 'giou')
+        loss_box = 1.0 - ious
+
+        return loss_box
+
+
+    def loss_dfl(self, pred_reg, gt_box, anchor, stride):
+        # rescale coords by stride
+        gt_box_s = gt_box / stride
+        anchor_s = anchor / stride
+
+        # compute deltas
+        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.cfg['reg_max'] - 1)
+
+        gt_left = gt_ltrb_s.to(torch.long)
+        gt_right = gt_left + 1
+
+        weight_left = gt_right.to(torch.float) - gt_ltrb_s
+        weight_right = 1 - weight_left
+
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_reg.view(-1, self.cfg['reg_max']),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_reg.view(-1, self.cfg['reg_max']),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss_dfl = (loss_left + loss_right).mean(-1, keepdim=True)
+            
+        return loss_dfl
+    
+    
+    def __call__(self, outputs, targets, epoch=0):        
+        """
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['strides']: List(Int) [8, 16, 32] output stride
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        bs = outputs['pred_cls'][0].shape[0]
+        device = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        anchors = outputs['anchors']
+        num_anchors = sum([ab.shape[0] for ab in anchors])
+        # preds: [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+
+        # --------------- label assignment ---------------
+        cls_targets = []
+        box_targets = []
+        fg_masks = []
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
+                # There is no valid gt
+                cls_target = cls_preds.new_zeros((num_anchors, self.num_classes))
+                box_target = cls_preds.new_zeros((0, 4))
+                fg_mask = cls_preds.new_zeros(num_anchors).bool()
+            else:
+                (
+                    fg_mask,
+                    assigned_labels,
+                    assigned_ious,
+                    assigned_indexs
+                ) = self.matcher(
+                    fpn_strides = fpn_strides,
+                    anchors = anchors,
+                    pred_cls = cls_preds[batch_idx], 
+                    pred_box = box_preds[batch_idx],
+                    tgt_labels = tgt_labels,
+                    tgt_bboxes = tgt_bboxes
+                    )
+                # prepare cls targets
+                assigned_labels = F.one_hot(assigned_labels.long(), self.num_classes)
+                assigned_labels = assigned_labels * assigned_ious.unsqueeze(-1)
+                cls_target = cls_preds.new_zeros((num_anchors, self.num_classes))
+                cls_target[fg_mask] = assigned_labels
+                # prepare box targets
+                box_target = tgt_bboxes[assigned_indexs]
+
+            cls_targets.append(cls_target)
+            box_targets.append(box_target)
+            fg_masks.append(fg_mask)
+
+        cls_targets = torch.cat(cls_targets, 0)
+        box_targets = torch.cat(box_targets, 0)
+        fg_masks = torch.cat(fg_masks, 0)
+        num_fgs = fg_masks.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+        
+        # ------------------ Classification loss ------------------
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        loss_cls = self.loss_classes(cls_preds, cls_targets)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # ------------------ Regression loss ------------------
+        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
+        loss_box = loss_box.sum() / num_fgs
+
+        # ------------------ Distribution focal loss  ------------------
+        ## process anchors
+        anchors = torch.cat(anchors, dim=0)
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
+        ## process stride tensors
+        strides = torch.cat(outputs['stride_tensor'], dim=0)
+        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
+        ## fg preds
+        reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
+        anchors_pos = anchors[fg_masks]
+        strides_pos = strides[fg_masks]
+        ## compute dfl
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos)
+        loss_dfl = loss_dfl.sum() / num_fgs
+
+        # total loss
+        losses = self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box + \
+                 self.loss_dfl_weight * loss_dfl
+
+        loss_dict = dict(
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                loss_dfl = loss_dfl,
+                losses = losses
+        )
+
+        return loss_dict
+    
+
+def build_criterion(cfg, device, num_classes):
+    criterion = Criterion(
+        cfg=cfg,
+        device=device,
+        num_classes=num_classes
+        )
+
+    return criterion
+
+
+if __name__ == "__main__":
+    pass
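
For the DFL term, each continuous per-side distance (in grid units, clamped to reg_max - 1 by bbox2dist in utils.box_ops) is split between its two neighbouring integer bins; a self-contained numeric sketch:

import torch
import torch.nn.functional as F

reg_max = 16
gt_dist = torch.tensor([3.4])                # distance from anchor to one box side, in grid units
gt_left = gt_dist.to(torch.long)             # bin 3
gt_right = gt_left + 1                       # bin 4
w_left = gt_right.to(torch.float) - gt_dist  # 0.6
w_right = 1.0 - w_left                       # 0.4

pred_logits = torch.randn(1, reg_max)        # one side's predicted distribution (untrained)
loss_dfl = (F.cross_entropy(pred_logits, gt_left, reduction='none') * w_left
            + F.cross_entropy(pred_logits, gt_right, reduction='none') * w_right)
print(gt_left.item(), gt_right.item(), round(w_left.item(), 2), round(w_right.item(), 2))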

+ 184 - 0
models/detectors/lodet/matcher.py

@@ -0,0 +1,184 @@
+# ---------------------------------------------------------------------
+# Copyright (c) Megvii Inc. All rights reserved.
+# ---------------------------------------------------------------------
+
+
+import torch
+import torch.nn.functional as F
+from utils.box_ops import *
+
+
+class AlignedSimOTA(object):
+    """
+        This code referenced to https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
+    """
+    def __init__(self, num_classes, center_sampling_radius, topk_candidate ):
+        self.num_classes = num_classes
+        self.center_sampling_radius = center_sampling_radius
+        self.topk_candidate = topk_candidate
+
+
+    @torch.no_grad()
+    def __call__(self, 
+                 fpn_strides, 
+                 anchors, 
+                 pred_cls, 
+                 pred_box, 
+                 tgt_labels,
+                 tgt_bboxes):
+        # [M,]
+        strides_tensor = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
+        # List[F, M, 2] -> [M, 2]
+        anchors = torch.cat(anchors, dim=0)
+        num_anchor = anchors.shape[0]        
+        num_gt = len(tgt_labels)
+
+        # ----------------------- Find inside points -----------------------
+        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+            tgt_bboxes, anchors, strides_tensor, num_anchor, num_gt)
+        cls_preds = pred_cls[fg_mask].float()   # [Mp, C]
+        box_preds = pred_box[fg_mask].float()   # [Mp, 4]
+
+        # ----------------------- Reg cost -----------------------
+        pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds)      # [N, Mp]
+        reg_cost = -torch.log(pair_wise_ious + 1e-8)            # [N, Mp]
+
+        # ----------------------- Cls cost -----------------------
+        with torch.cuda.amp.autocast(enabled=False):
+            # [Mp, C] -> [N, Mp, C]
+            score_preds = cls_preds.sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
+            # prepare cls_target
+            cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
+            cls_targets = cls_targets.unsqueeze(1).repeat(1, score_preds.size(1), 1)
+            cls_targets *= pair_wise_ious.unsqueeze(-1)  # iou-aware
+            # [N, Mp]
+            cls_cost = F.binary_cross_entropy(score_preds, cls_targets, reduction="none").sum(-1)
+        del score_preds
+
+        #----------------------- Dynamic K-Matching -----------------------
+        cost_matrix = (
+            cls_cost
+            + 3.0 * reg_cost
+            + 100000.0 * (~is_in_boxes_and_center)
+        ) # [N, Mp]
+
+        (
+            assigned_labels,         # [num_fg,]
+            assigned_ious,           # [num_fg,]
+            assigned_indexs,         # [num_fg,]
+        ) = self.dynamic_k_matching(
+            cost_matrix,
+            pair_wise_ious,
+            tgt_labels,
+            num_gt,
+            fg_mask
+            )
+        del cls_cost, cost_matrix, pair_wise_ious, reg_cost
+
+        return fg_mask, assigned_labels, assigned_ious, assigned_indexs
+
+
+    def get_in_boxes_info(
+        self,
+        gt_bboxes,   # [N, 4]
+        anchors,     # [M, 2]
+        strides,     # [M,]
+        num_anchors, # M
+        num_gt,      # N
+        ):
+        # anchor center
+        x_centers = anchors[:, 0]
+        y_centers = anchors[:, 1]
+
+        # [M,] -> [1, M] -> [N, M]
+        x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
+        y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
+
+        # [N,] -> [N, 1] -> [N, M]
+        gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
+        gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
+        gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
+        gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
+
+        b_l = x_centers - gt_bboxes_l
+        b_r = gt_bboxes_r - x_centers
+        b_t = y_centers - gt_bboxes_t
+        b_b = gt_bboxes_b - y_centers
+        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
+
+        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
+        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
+        # in fixed center
+        center_radius = self.center_sampling_radius
+
+        # [N, 2]
+        gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
+        
+        # [1, M]
+        center_radius_ = center_radius * strides.unsqueeze(0)
+
+        gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
+        gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
+        gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
+        gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
+
+        c_l = x_centers - gt_bboxes_l
+        c_r = gt_bboxes_r - x_centers
+        c_t = y_centers - gt_bboxes_t
+        c_b = gt_bboxes_b - y_centers
+        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
+        is_in_centers = center_deltas.min(dim=-1).values > 0.0
+        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+
+        # in boxes and in centers
+        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
+
+        is_in_boxes_and_center = (
+            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
+        )
+        return is_in_boxes_anchor, is_in_boxes_and_center
+    
+    
+    def dynamic_k_matching(
+        self, 
+        cost, 
+        pair_wise_ious, 
+        gt_classes, 
+        num_gt, 
+        fg_mask
+        ):
+        # Dynamic K
+        # ---------------------------------------------------------------
+        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+
+        ious_in_boxes_matrix = pair_wise_ious
+        n_candidate_k = min(self.topk_candidate, ious_in_boxes_matrix.size(1))
+        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+        dynamic_ks = dynamic_ks.tolist()
+        for gt_idx in range(num_gt):
+            _, pos_idx = torch.topk(
+                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
+            )
+            matching_matrix[gt_idx][pos_idx] = 1
+
+        del topk_ious, dynamic_ks, pos_idx
+
+        anchor_matching_gt = matching_matrix.sum(0)
+        if (anchor_matching_gt > 1).sum() > 0:
+            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
+            matching_matrix[:, anchor_matching_gt > 1] *= 0
+            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+
+        fg_mask[fg_mask.clone()] = fg_mask_inboxes
+
+        assigned_indexs = matching_matrix[:, fg_mask_inboxes].argmax(0)
+        assigned_labels = gt_classes[assigned_indexs]
+
+        assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[
+            fg_mask_inboxes
+        ]
+        return assigned_labels, assigned_ious, assigned_indexs
+