
Modify YOLOX2 & add E2E-YOLO

yjh0410 2 years ago
Commit 1ee9f93827

+ 5 - 5
README.md

@@ -160,14 +160,14 @@ python train.py --cuda -d coco --root path/to/COCO -m yolov1 -bs 16 --max_epoch
 | YOLOX2-M |  640  |  300  |                        |                   |                   |                    |  |
 | YOLOX2-L |  640  |  300  |                        |                   |                   |                    |  |
 
-* ETE-YOLO (End-to-End YOLO without NMS):
+* E2E-YOLO (End-to-End YOLO without NMS):
 
 | Model      | Scale | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
 |------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
-| ETE-YOLO-N |  640  |  300  |                        |                   |                   |                    |  |
-| ETE-YOLO-S |  640  |  300  |                        |                   |                   |                    |  |
-| ETE-YOLO-M |  640  |  300  |                        |                   |                   |                    |  |
-| ETE-YOLO-L |  640  |  300  |                        |                   |                   |                    |  |
+| E2E-YOLO-N |  640  |  300  |                        |                   |                   |                    |  |
+| E2E-YOLO-S |  640  |  300  |                        |                   |                   |                    |  |
+| E2E-YOLO-M |  640  |  300  |                        |                   |                   |                    |  |
+| E2E-YOLO-L |  640  |  300  |                        |                   |                   |                    |  |
 
 * Redesigned RT-DETR:
 

+ 4 - 0
config/__init__.py

@@ -83,6 +83,7 @@ from .model_config.yolov7_config import yolov7_cfg
 from .model_config.yolox_config import yolox_cfg
 from .model_config.yolox2_config import yolox2_cfg
 from .model_config.rtdetr_config import rtdetr_cfg
+from .model_config.e2eyolo_config import e2eyolo_cfg
 
 
 def build_model_config(args):
@@ -115,6 +116,9 @@ def build_model_config(args):
     # RT-DETR
     elif args.model in ['rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x']:
         cfg = rtdetr_cfg[args.model]
+    # E2E-YOLO
+    elif args.model in ['e2eyolo_n', 'e2eyolo_s', 'e2eyolo_m', 'e2eyolo_l', 'e2eyolo_x']:
+        cfg = e2eyolo_cfg[args.model]
 
     return cfg
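
A minimal sketch of how the new dispatch branch is exercised (hypothetical snippet; the Namespace stands in for the repo's real argparse args, and it assumes the repo root is on sys.path):

from argparse import Namespace
from config import build_model_config

args = Namespace(model='e2eyolo_n')
cfg = build_model_config(args)          # resolves to e2eyolo_cfg['e2eyolo_n']
print(cfg['backbone'], cfg['width'])    # elannet 0.25

Note that only 'e2eyolo_n' is defined in e2eyolo_cfg in this commit, so the _s/_m/_l/_x branches would raise a KeyError until those entries are added.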
 

+ 74 - 0
config/model_config/e2eyolo_config.py

@@ -0,0 +1,74 @@
+# E2E-YOLO config
+
+
+e2eyolo_cfg = {
+    'e2eyolo_n':{
+        # ---------------- Model config ----------------
+        ## Backbone
+        'backbone': 'elannet',
+        'pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'width': 0.25,
+        'depth': 0.34,
+        'stride': [8, 16, 32],  # P3, P4, P5
+        'max_stride': 32,
+        ## Neck: SPP
+        'neck': 'sppf',
+        'neck_expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+        ## Neck: PaFPN
+        'fpn': 'yolo_pafpn',
+        'fpn_reduce_layer': 'Conv',
+        'fpn_downsample_layer': 'Conv',
+        'fpn_core_block': 'elanblock',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        ## Head
+        'head': 'decoupled_head',
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_depthwise': False,
+        'head_groups': 1,
+        # ---------------- Train config ----------------
+        ## input
+        'multi_scale': [0.5, 1.5],   # 320 -> 960
+        'trans_type': 'yolox_nano',
+        # ---------------- Assignment config ----------------
+        ## matcher
+        'matcher': {'topk': 10,
+                    'alpha': 0.5,
+                    'beta': 6.0},
+        # ---------------- Loss config ----------------
+        ## loss weight
+        'loss_obj_weight': 1.0,
+        'loss_cls_weight': 1.0,
+        'loss_box_weight': 5.0,
+        # ---------------- Train config ----------------
+        ## close strong augmentation
+        'no_aug_epoch': 20,
+        'trainer_type': 'rtmdet',
+        ## optimizer
+        'optimizer': 'adamw',      # options: 'sgd', 'adamw'
+        'momentum': None,          # SGD: 0.9;      AdamW: None
+        'weight_decay': 5e-2,      # SGD: 5e-4;     AdamW: 5e-2
+        'clip_grad': 15,           # SGD: 10.0;     AdamW: -1
+        ## model EMA
+        'ema_decay': 0.9998,       # SGD: 0.9999;   AdamW: 0.9998
+        'ema_tau': 2000,
+        ## lr schedule
+        'scheduler': 'linear',
+        'lr0': 0.001,              # SGD: 0.01;     AdamW: 0.001
+        'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.01
+        'warmup_momentum': 0.8,
+        'warmup_bias_lr': 0.1,
+    },
+
+}
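
A side note: build_model_config and build_model both accept five e2eyolo sizes, but only the nano entry exists here. A hedged sketch of deriving the rest, reusing the width/depth pairs that build_backbone's pretrained-weight table assumes; in practice per-scale fields such as trans_type would also change:

# Hypothetical: fill in the missing scales from the nano entry.
for name, w, d in [('e2eyolo_s', 0.50, 0.34), ('e2eyolo_m', 0.75, 0.67),
                   ('e2eyolo_l', 1.00, 1.00), ('e2eyolo_x', 1.25, 1.34)]:
    scaled = dict(e2eyolo_cfg['e2eyolo_n'])   # shallow copy of the nano config
    scaled.update(width=w, depth=d)
    e2eyolo_cfg[name] = scaled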

+ 5 - 0
models/detectors/__init__.py

@@ -11,6 +11,7 @@ from .yolov7.build import build_yolov7
 from .yolox.build import build_yolox
 from .yolox2.build import build_yolox2
 from .rtdetr.build import build_rtdetr
+from .e2eyolo.build import build_e2eyolo
 
 
 # build object detector
@@ -56,6 +57,10 @@ def build_model(args,
     elif args.model in ['rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x']:
         model, criterion = build_rtdetr(
             args, model_cfg, device, num_classes, trainable, deploy)
+    # E2E-YOLO
+    elif args.model in ['e2eyolo_n', 'e2eyolo_s', 'e2eyolo_m', 'e2eyolo_l', 'e2eyolo_x']:
+        model, criterion = build_e2eyolo(
+            args, model_cfg, device, num_classes, trainable, deploy)
 
     if trainable:
         # Load pretrained weight

+ 61 - 0
models/detectors/e2eyolo/build.py

@@ -0,0 +1,61 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+import torch.nn as nn
+
+from .loss import build_criterion
+from .e2eyolo import E2EYOLO
+
+
+# build object detector
+def build_e2eyolo(args, cfg, device, num_classes=80, trainable=False, deploy=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+        
+    # -------------- Build YOLO --------------
+    model = E2EYOLO(
+        cfg=cfg,
+        device=device, 
+        num_classes=num_classes,
+        trainable=trainable,
+        conf_thresh=args.conf_thresh,
+        nms_thresh=args.nms_thresh,
+        topk=args.topk,
+        deploy=deploy
+        )
+
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+    # Init head
+    init_prob = 0.01
+    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+    ## obj pred
+    for obj_pred in model.obj_preds:
+        b = obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    ## cls pred
+    for cls_pred in model.cls_preds:
+        b = cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    ## reg pred
+    for reg_pred in model.reg_preds:
+        b = reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = reg_pred.weight
+        w.data.fill_(0.)
+        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+    # -------------- Build criterion --------------
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, device, num_classes)
+    return model, criterion
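
A quick numeric check of the prior-probability initialization above: with init_prob = 0.01 the obj/cls biases start at -log(99) ≈ -4.595, so sigmoid(bias) ≈ 0.01 and almost every location begins as background:

import math

p = 0.01
bias = -math.log((1.0 - p) / p)
print(bias)                            # -4.5951...
print(1.0 / (1.0 + math.exp(-bias)))   # 0.01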

+ 260 - 0
models/detectors/e2eyolo/e2eyolo.py

@@ -0,0 +1,260 @@
+# --------------- Torch components ---------------
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from .e2eyolo_backbone import build_backbone
+from .e2eyolo_neck import build_neck
+from .e2eyolo_pafpn import build_fpn
+from .e2eyolo_head import build_head
+
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
+
+
+# E2E-YOLO
+class E2EYOLO(nn.Module):
+    def __init__(self, 
+                 cfg,
+                 device, 
+                 num_classes = 20, 
+                 conf_thresh = 0.05,
+                 nms_thresh = 0.6,
+                 trainable = False, 
+                 topk = 1000,
+                 deploy = False):
+        super(E2EYOLO, self).__init__()
+        # ---------------------- Basic Parameters ----------------------
+        self.cfg = cfg
+        self.device = device
+        self.stride = cfg['stride']
+        self.num_classes = num_classes
+        self.trainable = trainable
+        self.conf_thresh = conf_thresh
+        self.nms_thresh = nms_thresh
+        self.topk = topk
+        self.deploy = deploy
+        self.head_dim = round(256*cfg['width'])
+        
+        # ---------------------- Network Parameters ----------------------
+        ## ----------- Backbone -----------
+        self.backbone, feats_dim = build_backbone(cfg, trainable and cfg['pretrained'])
+
+        ## ----------- Neck: SPP -----------
+        self.neck = build_neck(cfg=cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1])
+        feats_dim[-1] = self.neck.out_dim
+        
+        ## ----------- Neck: FPN -----------
+        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=round(256*cfg['width']))
+        self.fpn_dims = self.fpn.out_dim
+
+        ## ----------- Heads -----------
+        self.group_heads = build_head(cfg, self.fpn_dims, self.head_dim, num_classes) 
+
+        ## ----------- Preds -----------
+        self.obj_preds = nn.ModuleList(
+                            [nn.Conv2d(self.head_dim, 1, kernel_size=1) 
+                                for _ in range(len(self.stride))
+                              ]) 
+        self.cls_preds = nn.ModuleList(
+                            [nn.Conv2d(self.head_dim, num_classes, kernel_size=1) 
+                                for _ in range(len(self.stride))
+                              ]) 
+        self.reg_preds = nn.ModuleList(
+                            [nn.Conv2d(self.head_dim, 4, kernel_size=1) 
+                                for _ in range(len(self.stride))
+                              ])                 
+
+    # ---------------------- Basic Functions ----------------------
+    ## generate anchor points
+    def generate_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchor_xy += 0.5  # add center offset
+        anchor_xy *= self.stride[level]
+        anchors = anchor_xy.to(self.device)
+
+        return anchors
+        
+    ## post-process
+    def post_process(self, obj_preds, cls_preds, box_preds):
+        """
+        Input:
+            obj_preds: List(Tensor) [[H x W, 1], ...]
+            cls_preds: List(Tensor) [[H x W, C], ...]
+            box_preds: List(Tensor) [[H x W, 4], ...]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
+            # (H x W x C,)
+            scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk, box_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
+
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            labels = topk_idxs % self.num_classes
+
+            bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
+
+        return bboxes, scores, labels
+
+    
+    # ---------------------- Main Process for Inference ----------------------
+    @torch.no_grad()
+    def inference_single_image(self, x):
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(x)
+
+        # ---------------- Neck: SPP ----------------
+        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+        # ---------------- Neck: PaFPN ----------------
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.group_heads(pyramid_feats)
+
+        # ---------------- Preds ----------------
+        all_obj_preds = []
+        all_cls_preds = []
+        all_box_preds = []
+        for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
+            # prediction
+            obj_pred = self.obj_preds[level](reg_feat)
+            cls_pred = self.cls_preds[level](cls_feat)
+            reg_pred = self.reg_preds[level](reg_feat)
+            
+            # anchors: [M, 2]
+            fmp_size = cls_feat.shape[-2:]
+            anchors = self.generate_anchors(level, fmp_size)
+            
+            # [1, C, H, W] -> [H, W, C] -> [M, C]
+            obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
+            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
+            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
+
+            # decode bbox
+            ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
+            wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
+            pred_x1y1 = ctr_pred - wh_pred * 0.5
+            pred_x2y2 = ctr_pred + wh_pred * 0.5
+            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+            all_obj_preds.append(obj_pred)
+            all_cls_preds.append(cls_pred)
+            all_box_preds.append(box_pred)
+
+        if self.deploy:
+            obj_preds = torch.cat(all_obj_preds, dim=0)
+            cls_preds = torch.cat(all_cls_preds, dim=0)
+            box_preds = torch.cat(all_box_preds, dim=0)
+            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
+            bboxes = box_preds
+            # [n_anchors_all, 4 + C]
+            outputs = torch.cat([bboxes, scores], dim=-1)
+
+            return outputs
+        else:
+            # post process
+            bboxes, scores, labels = self.post_process(
+                all_obj_preds, all_cls_preds, all_box_preds)
+        
+            return bboxes, scores, labels
+
+
+    # ---------------------- Main Process for Training ----------------------
+    def forward(self, x):
+        if not self.trainable:
+            return self.inference_single_image(x)
+        else:
+            # ---------------- Backbone ----------------
+            pyramid_feats = self.backbone(x)
+
+            # ---------------- Neck: SPP ----------------
+            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+            # ---------------- Neck: PaFPN ----------------
+            pyramid_feats = self.fpn(pyramid_feats)
+
+            # ---------------- Heads ----------------
+            cls_feats, reg_feats = self.group_heads(pyramid_feats)
+
+            # ---------------- Preds ----------------
+            all_anchors = []
+            all_obj_preds = []
+            all_cls_preds = []
+            all_box_preds = []
+            for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
+                # prediction
+                obj_pred = self.obj_preds[level](reg_feat)
+                cls_pred = self.cls_preds[level](cls_feat)
+                reg_pred = self.reg_preds[level](reg_feat)
+
+                B, _, H, W = cls_pred.size()
+                fmp_size = [H, W]
+                # generate anchor points: [M, 2]
+                anchors = self.generate_anchors(level, fmp_size)
+                
+                # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+                obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
+                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+
+                # decode bbox
+                ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
+                wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
+                pred_x1y1 = ctr_pred - wh_pred * 0.5
+                pred_x2y2 = ctr_pred + wh_pred * 0.5
+                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+                all_obj_preds.append(obj_pred)
+                all_cls_preds.append(cls_pred)
+                all_box_preds.append(box_pred)
+                all_anchors.append(anchors)
+            
+            # output dict
+            outputs = {"pred_obj": all_obj_preds,        # List(Tensor) [B, M, 1]
+                       "pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
+                       "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
+                       "anchors": all_anchors,           # List(Tensor) [B, M, 2]
+                       'strides': self.stride}           # List(Int) [8, 16, 32]
+
+            return outputs 
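
A small numeric check of the box decoding shared by both branches (stride-8 level, the anchor at grid cell (0, 0)): the predicted center offset is scaled by the stride and added to the anchor point, and width/height come from exp of the raw prediction:

import torch

stride = 8
anchor = torch.tensor([4.0, 4.0])          # (0 + 0.5) * stride
reg = torch.tensor([0.5, 0.5, 0.0, 0.0])   # (dx, dy, log w, log h)
ctr = reg[:2] * stride + anchor            # -> (8., 8.)
wh = torch.exp(reg[2:]) * stride           # -> (8., 8.)
print(torch.cat([ctr - wh * 0.5, ctr + wh * 0.5]))   # tensor([ 4.,  4., 12., 12.])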

+ 154 - 0
models/detectors/e2eyolo/e2eyolo_backbone.py

@@ -0,0 +1,154 @@
+import torch
+import torch.nn as nn
+try:
+    from .e2eyolo_basic import Conv, ELANBlock, DownSample
+except ImportError:
+    from e2eyolo_basic import Conv, ELANBlock, DownSample
+
+
+
+model_urls = {
+    'elannet_pico': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_pico.pth",
+    'elannet_nano': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_nano.pth",
+    'elannet_small': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_small.pth",
+    'elannet_medium': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_medium.pth",
+    'elannet_large': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_large.pth",
+    'elannet_huge': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_huge.pth",
+}
+
+
+# ---------------------------- Backbones ----------------------------
+# ELANNet-P5
+class ELANNet(nn.Module):
+    def __init__(self, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANNet, self).__init__()
+        self.feat_dims = [int(512 * width), int(1024 * width), int(1024 * width)]
+        
+        # P1/2
+        self.layer_1 = nn.Sequential(
+            Conv(3, int(64*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            Conv(int(64*width), int(64*width), k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P2/4
+        self.layer_2 = nn.Sequential(   
+            Conv(int(64*width), int(128*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
+            ELANBlock(in_dim=int(128*width), out_dim=int(256*width), expand_ratio=0.5, depth=depth,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            DownSample(in_dim=int(256*width), out_dim=int(256*width), act_type=act_type, norm_type=norm_type),             
+            ELANBlock(in_dim=int(256*width), out_dim=int(512*width), expand_ratio=0.5, depth=depth,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            DownSample(in_dim=int(512*width), out_dim=int(512*width), act_type=act_type, norm_type=norm_type),             
+            ELANBlock(in_dim=int(512*width), out_dim=int(1024*width), expand_ratio=0.5, depth=depth,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            DownSample(in_dim=int(1024*width), out_dim=int(1024*width), act_type=act_type, norm_type=norm_type),             
+            ELANBlock(in_dim=int(1024*width), out_dim=int(1024*width), expand_ratio=0.25, depth=depth,
+                    act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+## load pretrained weight
+def load_weight(model, model_name):
+    # load weight
+    print('Loading pretrained weight ...')
+    url = model_urls[model_name]
+    if url is not None:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url=url, map_location="cpu", check_hash=True)
+        # checkpoint state dict
+        checkpoint_state_dict = checkpoint.pop("model")
+        # model state dict
+        model_state_dict = model.state_dict()
+        # check
+        for k in list(checkpoint_state_dict.keys()):
+            if k in model_state_dict:
+                shape_model = tuple(model_state_dict[k].shape)
+                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                if shape_model != shape_checkpoint:
+                    checkpoint_state_dict.pop(k)
+            else:
+                checkpoint_state_dict.pop(k)
+                print('Removed unmatched checkpoint key: {}'.format(k))
+
+        model.load_state_dict(checkpoint_state_dict)
+    else:
+        print('No pretrained for {}'.format(model_name))
+
+    return model
+
+
+## build ELAN-Net
+def build_backbone(cfg, pretrained=False): 
+    # model
+    backbone = ELANNet(
+        width=cfg['width'],
+        depth=cfg['depth'],
+        act_type=cfg['bk_act'],
+        norm_type=cfg['bk_norm'],
+        depthwise=cfg['bk_dpw']
+        )
+    # check whether to load imagenet pretrained weight
+    if pretrained:
+        if cfg['width'] == 0.25 and cfg['depth'] == 0.34 and cfg['bk_dpw']:
+            backbone = load_weight(backbone, model_name='elannet_pico')
+        elif cfg['width'] == 0.25 and cfg['depth'] == 0.34:
+            backbone = load_weight(backbone, model_name='elannet_nano')
+        elif cfg['width'] == 0.5 and cfg['depth'] == 0.34:
+            backbone = load_weight(backbone, model_name='elannet_small')
+        elif cfg['width'] == 0.75 and cfg['depth'] == 0.67:
+            backbone = load_weight(backbone, model_name='elannet_medium')
+        elif cfg['width'] == 1.0 and cfg['depth'] == 1.0:
+            backbone = load_weight(backbone, model_name='elannet_large')
+        elif cfg['width'] == 1.25 and cfg['depth'] == 1.34:
+            backbone = load_weight(backbone, model_name='elannet_huge')
+    feat_dims = backbone.feat_dims
+
+    return backbone, feat_dims
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': True,
+        'width': 0.25,
+        'depth': 0.34,
+    }
+    model, feats = build_backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
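
For orientation, the nano scale (width = 0.25) gives the following per-level output shapes at a 640x640 input, following the feat_dims rule [int(512*w), int(1024*w), int(1024*w)] above:

w, img = 0.25, 640
dims = [int(512 * w), int(1024 * w), int(1024 * w)]   # [128, 256, 256]
for s, d in zip([8, 16, 32], dims):                   # strides of c3, c4, c5
    print('stride {:2d}: [1, {}, {}, {}]'.format(s, d, img // s, img // s))
# stride  8: [1, 128, 80, 80]
# stride 16: [1, 256, 40, 40]
# stride 32: [1, 256, 20, 20]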

+ 191 - 0
models/detectors/e2eyolo/e2eyolo_basic.py

@@ -0,0 +1,191 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+# ---------------------------- 2D CNN ----------------------------
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+
+
+# Basic conv layer
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type='lrelu',     # activation
+                 norm_type='BN',       # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        p = p if d == 1 else d
+        if depthwise:
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            # depthwise conv
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+
+# ---------------------------- Modified YOLOv7's Modules ----------------------------
+## ELANBlock
+class ELANBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANBlock, self).__init__()
+        if isinstance(expand_ratio, float):
+            inter_dim = int(in_dim * expand_ratio)
+            inter_dim2 = inter_dim
+        elif isinstance(expand_ratio, list):
+            assert len(expand_ratio) == 2
+            e1, e2 = expand_ratio
+            inter_dim = int(in_dim * e1)
+            inter_dim2 = int(inter_dim * e2)
+        # branch-1
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        # branch-2
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        # branch-3
+        for idx in range(round(3*depth)):
+            if idx == 0:
+                cv3 = [Conv(inter_dim, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
+            else:
+                cv3.append(Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise))
+        self.cv3 = nn.Sequential(*cv3)
+        # branch-4
+        self.cv4 = nn.Sequential(*[
+            Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(round(3*depth))
+        ])
+        # output
+        self.out = Conv(inter_dim*2 + inter_dim2*2, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+
+    def forward(self, x):
+        """
+        Input:
+            x: [B, C_in, H, W]
+        Output:
+            out: [B, C_out, H, W]
+        """
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.cv3(x2)
+        x4 = self.cv4(x3)
+
+        # [B, C, H, W] -> [B, 2C, H, W]
+        out = self.out(torch.cat([x1, x2, x3, x4], dim=1))
+
+        return out
+
+## DownSample
+class DownSample(nn.Module):
+    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+        super().__init__()
+        inter_dim = out_dim // 2
+        self.mp = nn.MaxPool2d((2, 2), 2)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = nn.Sequential(
+            Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type),
+            Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+    def forward(self, x):
+        """
+        Input:
+            x: [B, C, H, W]
+        Output:
+            out: [B, C, H//2, W//2]
+        """
+        # [B, C, H, W] -> [B, C//2, H//2, W//2]
+        x1 = self.cv1(self.mp(x))
+        x2 = self.cv2(x)
+
+        # [B, C, H//2, W//2]
+        out = torch.cat([x1, x2], dim=1)
+
+        return out
+
+
+# ---------------------------- FPN Modules ----------------------------
+## build fpn's core block
+def build_fpn_block(cfg, in_dim, out_dim):
+    if cfg['fpn_core_block'] == 'elanblock':
+        layer = ELANBlock(in_dim=in_dim,
+                          out_dim=out_dim,
+                          expand_ratio=[0.5, 0.5],
+                          depth=cfg['depth'],
+                          act_type=cfg['fpn_act'],
+                          norm_type=cfg['fpn_norm'],
+                          depthwise=cfg['fpn_depthwise']
+                          )
+        
+    return layer
+
+## build fpn's reduce layer
+def build_reduce_layer(cfg, in_dim, out_dim):
+    if cfg['fpn_reduce_layer'] == 'Conv':
+        layer = Conv(in_dim, out_dim, k=1, act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+        
+    return layer
+
+## build fpn's downsample layer
+def build_downsample_layer(cfg, in_dim, out_dim):
+    if cfg['fpn_downsample_layer'] == 'Conv':
+        layer = Conv(in_dim, out_dim, k=3, s=2, p=1, act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+        
+    return layer
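
On ELANBlock's channel bookkeeping: cv1 and cv2 each emit inter_dim channels and cv3/cv4 each emit inter_dim2, which is exactly the inter_dim*2 + inter_dim2*2 input width of self.out. A quick check for the float-ratio case:

in_dim, expand_ratio = 128, 0.5
inter_dim = int(in_dim * expand_ratio)   # 64
inter_dim2 = inter_dim                   # float ratio: inter_dim2 == inter_dim
print(inter_dim * 2 + inter_dim2 * 2)    # 256 channels into self.out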

+ 113 - 0
models/detectors/e2eyolo/e2eyolo_head.py

@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+
+from .e2eyolo_basic import Conv
+
+
+class SingleLevelHead(nn.Module):
+    def __init__(self, in_dim, out_dim, num_classes, num_cls_head, num_reg_head, act_type, norm_type, depthwise):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dim = in_dim
+        self.num_classes = num_classes
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        self.act_type = act_type
+        self.norm_type = norm_type
+        self.depthwise = depthwise
+        
+        # --------- Network Parameters ----------
+        ## cls head
+        cls_feats = []
+        self.cls_out_dim = out_dim
+        for i in range(num_cls_head):
+            if i == 0:
+                cls_feats.append(
+                    Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                         act_type=act_type,
+                         norm_type=norm_type,
+                         depthwise=depthwise)
+                        )
+            else:
+                cls_feats.append(
+                    Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                         act_type=act_type,
+                         norm_type=norm_type,
+                         depthwise=depthwise)
+                        )
+        ## reg head
+        reg_feats = []
+        self.reg_out_dim = out_dim
+        for i in range(num_reg_head):
+            if i == 0:
+                reg_feats.append(
+                    Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                         act_type=act_type,
+                         norm_type=norm_type,
+                         depthwise=depthwise)
+                        )
+            else:
+                reg_feats.append(
+                    Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                         act_type=act_type,
+                         norm_type=norm_type,
+                         depthwise=depthwise)
+                        )
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+
+    def forward(self, x):
+        """
+            in_feats: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+
+class MultiLevelHead(nn.Module):
+    def __init__(self, cfg, in_dims, out_dim, num_classes=80):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.num_classes = num_classes
+
+        ## ----------- Network Parameters -----------
+        self.det_heads = nn.ModuleList(
+            [SingleLevelHead(
+                in_dim,
+                out_dim,
+                num_classes,
+                cfg['num_cls_head'],
+                cfg['num_reg_head'],
+                cfg['head_act'],
+                cfg['head_norm'],
+                cfg['head_depthwise'])
+                for in_dim in in_dims
+            ])
+
+
+    def forward(self, feats):
+        """
+            feats: List[(Tensor)] [[B, C, H, W], ...]
+        """
+        cls_feats = []
+        reg_feats = []
+        for feat, head in zip(feats, self.det_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+
+        return cls_feats, reg_feats
+    
+
+# build detection head
+def build_head(cfg, in_dim, out_dim, num_classes=80):
+    if cfg['head'] == 'decoupled_head':
+        head = MultiLevelHead(cfg, in_dim, out_dim, num_classes) 
+
+    return head
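
A hedged smoke test of the decoupled heads (assumes the repo root is on sys.path; the 64-channel dims mirror the nano PaFPN output):

import torch
from models.detectors.e2eyolo.e2eyolo_head import build_head

cfg = {'head': 'decoupled_head', 'num_cls_head': 2, 'num_reg_head': 2,
       'head_act': 'silu', 'head_norm': 'BN', 'head_depthwise': False}
head = build_head(cfg, in_dim=[64, 64, 64], out_dim=64, num_classes=80)
feats = [torch.randn(1, 64, s, s) for s in (80, 40, 20)]
cls_feats, reg_feats = head(feats)
print([tuple(f.shape) for f in cls_feats])   # [(1, 64, 80, 80), (1, 64, 40, 40), (1, 64, 20, 20)]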

+ 71 - 0
models/detectors/e2eyolo/e2eyolo_neck.py

@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+from .e2eyolo_basic import Conv
+
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    """
+        Referenced from https://github.com/ultralytics/yolov5
+    """
+    def __init__(self, cfg, in_dim, out_dim, expand_ratio=0.5):
+        super().__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.m = nn.MaxPool2d(kernel_size=cfg['pooling_size'], stride=1, padding=cfg['pooling_size'] // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+
+
+# SPPF block with CSP module
+class SPPFBlockCSP(nn.Module):
+    """
+        CSP Spatial Pyramid Pooling Block
+    """
+    def __init__(self, cfg, in_dim, out_dim, expand_ratio):
+        super(SPPFBlockCSP, self).__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.m = nn.Sequential(
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=cfg['neck_act'], norm_type=cfg['neck_norm'], 
+                 depthwise=cfg['neck_depthwise']),
+            SPPF(cfg, inter_dim, inter_dim, expand_ratio=1.0),
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=cfg['neck_act'], norm_type=cfg['neck_norm'], 
+                 depthwise=cfg['neck_depthwise'])
+        )
+        self.cv3 = Conv(inter_dim * 2, self.out_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+
+        
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.m(x2)
+        y = self.cv3(torch.cat([x1, x3], dim=1))
+
+        return y
+
+
+# build neck
+def build_neck(cfg, in_dim, out_dim):
+    model = cfg['neck']
+    print('==============================')
+    print('Neck: {}'.format(model))
+    # build neck
+    if model == 'sppf':
+        neck = SPPF(cfg, in_dim, out_dim, cfg['neck_expand_ratio'])
+    elif model == 'csp_sppf':
+        neck = SPPFBlockCSP(cfg, in_dim, out_dim, cfg['neck_expand_ratio'])
+
+    return neck
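
Why SPPF is the "fast" variant: three cascaded k=5 max-pools reproduce an SPP over kernel sizes 5/9/13 (MaxPool2d pads with -inf, so the compositions also match at the borders) while sharing intermediate results. A quick check:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 16, 16)
m5 = nn.MaxPool2d(5, stride=1, padding=2)
y1 = m5(x)
y2 = m5(y1)
y3 = m5(y2)
print(torch.equal(y2, nn.MaxPool2d(9, stride=1, padding=4)(x)),
      torch.equal(y3, nn.MaxPool2d(13, stride=1, padding=6)(x)))   # True True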

+ 94 - 0
models/detectors/e2eyolo/e2eyolo_pafpn.py

@@ -0,0 +1,94 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .e2eyolo_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
+
+
+# YOLO-Style PaFPN
+class YoloPaFPN(nn.Module):
+    def __init__(self, cfg, in_dims=[256, 512, 1024], out_dim=None):
+        super(YoloPaFPN, self).__init__()
+        # --------------------------- Basic Parameters ---------------------------
+        self.in_dims = in_dims
+        c3, c4, c5 = in_dims
+        width = cfg['width']
+
+        # --------------------------- Network Parameters ---------------------------
+        ## top down
+        ### P5 -> P4
+        self.reduce_layer_1 = build_reduce_layer(cfg, c5, round(512*width))
+        self.reduce_layer_2 = build_reduce_layer(cfg, c4, round(512*width))
+        self.top_down_layer_1 = build_fpn_block(cfg, round(512*width) + round(512*width), round(512*width))
+
+        ### P4 -> P3
+        self.reduce_layer_3 = build_reduce_layer(cfg, round(512*width), round(256*width))
+        self.reduce_layer_4 = build_reduce_layer(cfg, c3, round(256*width))
+        self.top_down_layer_2 = build_fpn_block(cfg, round(256*width) + round(256*width), round(256*width))
+
+        ## bottom up
+        ### P3 -> P4
+        self.downsample_layer_1 = build_downsample_layer(cfg, round(256*width), round(256*width))
+        self.bottom_up_layer_1 = build_fpn_block(cfg, round(256*width) + round(256*width), round(512*width))
+
+        ### P4 -> P5
+        self.downsample_layer_2 = build_downsample_layer(cfg, round(512*width), round(512*width))
+        self.bottom_up_layer_2 = build_fpn_block(cfg, round(512*width) + round(512*width), round(1024*width))
+                
+        ## output proj layers
+        if out_dim is not None:
+            self.out_layers = nn.ModuleList([
+                Conv(in_dim, out_dim, k=1,
+                     act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+                     for in_dim in [round(256*width), round(512*width), round(1024*width)]
+                     ])
+            self.out_dim = [out_dim] * 3
+        else:
+            self.out_layers = None
+            self.out_dim = [round(256*width), round(512*width), round(1024*width)]
+
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # Top down
+        ## P5 -> P4
+        c6 = self.reduce_layer_1(c5)
+        c7 = self.reduce_layer_2(c4)
+        c8 = torch.cat([F.interpolate(c6, scale_factor=2.0), c7], dim=1)
+        c9 = self.top_down_layer_1(c8)
+        ## P4 -> P3
+        c10 = self.reduce_layer_3(c9)
+        c11 = self.reduce_layer_4(c3)
+        c12 = torch.cat([F.interpolate(c10, scale_factor=2.0), c11], dim=1)
+        c13 = self.top_down_layer_2(c12)
+
+        # Bottom up
+        # P3 -> P4
+        c14 = self.downsample_layer_1(c13)
+        c15 = torch.cat([c14, c10], dim=1)
+        c16 = self.bottom_up_layer_1(c15)
+        # P4 -> P5
+        c17 = self.downsample_layer_2(c16)
+        c18 = torch.cat([c17, c6], dim=1)
+        c19 = self.bottom_up_layer_2(c18)
+
+        out_feats = [c13, c16, c19] # [P3, P4, P5]
+        
+        # output proj layers
+        if self.out_layers is not None:
+            out_feats_proj = []
+            for feat, layer in zip(out_feats, self.out_layers):
+                out_feats_proj.append(layer(feat))
+            return out_feats_proj
+
+        return out_feats
+
+
+def build_fpn(cfg, in_dims, out_dim=None):
+    model = cfg['fpn']
+    # build pafpn
+    if model == 'yolo_pafpn':
+        fpn_net = YoloPaFPN(cfg, in_dims, out_dim)
+
+    return fpn_net
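
Tracing the nano widths (width = 0.25) through the PaFPN: the internal pyramid runs at round(256*w)/round(512*w)/round(1024*w) channels, and since the detector passes out_dim = round(256*w), the out_layers projection leaves every level 64-wide for the shared prediction convs:

w = 0.25
print([round(256 * w), round(512 * w), round(1024 * w)])   # [64, 128, 256] before projection
print([round(256 * w)] * 3)                                # [64, 64, 64] after out_layers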

+ 170 - 0
models/detectors/e2eyolo/loss.py

@@ -0,0 +1,170 @@
+import torch
+import torch.nn.functional as F
+from .matcher import TaskAlignedAssigner
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+
+
+class Criterion(object):
+    def __init__(self, 
+                 cfg, 
+                 device, 
+                 num_classes=80):
+        self.cfg = cfg
+        self.device = device
+        self.num_classes = num_classes
+        # loss weight
+        self.loss_obj_weight = cfg['loss_obj_weight']
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_box_weight = cfg['loss_box_weight']
+        # matcher
+        matcher_config = cfg['matcher']
+        self.matcher = TaskAlignedAssigner(
+            topk=matcher_config['topk'],
+            num_classes=num_classes,
+            alpha=matcher_config['alpha'],
+            beta=matcher_config['beta']
+            )
+
+
+    def loss_objectness(self, pred_obj, gt_obj):
+        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
+
+        return loss_obj
+    
+
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+
+        return loss_cls
+
+
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box,
+                        gt_box,
+                        box_mode="xyxy",
+                        iou_type='giou')
+        loss_box = 1.0 - ious
+
+        return loss_box
+
+
+    def __call__(self, outputs, targets):        
+        """
+            outputs['pred_obj']: List(Tensor) [B, M, 1]
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['anchors']: List(Tensor) [M, 2]
+            outputs['strides']: List(Int) [8, 16, 32] output strides
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        bs = outputs['pred_cls'][0].shape[0]
+        device = outputs['pred_cls'][0].device
+        anchors = torch.cat(outputs['anchors'], dim=0)
+        num_anchors = anchors.shape[0]
+
+        # preds: [B, M, C]
+        obj_preds = torch.cat(outputs['pred_obj'], dim=1)
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+        
+        # label assignment
+        gt_label_targets = []
+        gt_score_targets = []
+        gt_bbox_targets = []
+        fg_masks = []
+
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
+                # There is no valid gt
+                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_label = cls_preds.new_zeros((1, num_anchors,))                  #[1, M,]
+                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
+                gt_box = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
+            else:
+                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
+                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
+                (
+                    gt_label,   #[1, M]
+                    gt_box,     #[1, M, 4]
+                    gt_score,   #[1, M, C]
+                    fg_mask,    #[1, M,]
+                    _
+                ) = self.matcher(
+                    pd_scores = torch.sqrt(obj_preds[batch_idx:batch_idx+1].sigmoid() * \
+                                           cls_preds[batch_idx:batch_idx+1].sigmoid()).detach(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    anc_points = anchors,
+                    gt_labels = tgt_labels,
+                    gt_bboxes = tgt_boxs
+                    )
+            gt_label_targets.append(gt_label)
+            gt_score_targets.append(gt_score)
+            gt_bbox_targets.append(gt_box)
+            fg_masks.append(fg_mask)
+
+        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
+        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
+        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
+        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
+
+        obj_targets = fg_masks.unsqueeze(-1)        # [BM, 1]
+        cls_targets = gt_score_targets[fg_masks]    # [Npos, C]
+        box_targets = gt_bbox_targets[fg_masks]     # [Npos, 4]
+        num_fgs = fg_masks.sum()
+        
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # obj loss
+        loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
+        loss_obj = loss_obj.sum() / num_fgs
+        
+        # cls loss
+        cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
+        loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # regression loss
+        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
+        loss_box = loss_box.sum() / num_fgs
+
+        # total loss
+        losses = self.loss_obj_weight * loss_obj + \
+                 self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box
+
+        loss_dict = dict(
+                loss_obj = loss_obj,
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                losses = losses
+        )
+
+        return loss_dict
+    
+
+def build_criterion(cfg, device, num_classes):
+    criterion = Criterion(
+        cfg=cfg,
+        device=device,
+        num_classes=num_classes
+        )
+
+    return criterion
+
+
+if __name__ == "__main__":
+    pass
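
Worth noting: the matcher is scored with sqrt(sigmoid(obj) * sigmoid(cls)), the same geometric-mean confidence that post_process uses at inference, so assignment and test-time ranking agree. A tiny check:

import torch

obj, cls = torch.tensor(2.0), torch.tensor(0.0)
print(torch.sqrt(obj.sigmoid() * cls.sigmoid()))   # sqrt(0.8808 * 0.5) ≈ 0.6636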

+ 203 - 0
models/detectors/e2eyolo/matcher.py

@@ -0,0 +1,203 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.box_ops import bbox_iou
+
+
+# -------------------------- Task Aligned Assigner --------------------------
+class TaskAlignedAssigner(nn.Module):
+    def __init__(self,
+                 topk=10,
+                 num_classes=80,
+                 alpha=0.5,
+                 beta=6.0, 
+                 eps=1e-9):
+        super(TaskAlignedAssigner, self).__init__()
+        self.topk = topk
+        self.num_classes = num_classes
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
+
+    @torch.no_grad()
+    def forward(self,
+                pd_scores,
+                pd_bboxes,
+                anc_points,
+                gt_labels,
+                gt_bboxes):
+        """This code referenced to
+           https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
+        Args:
+            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            anc_points (Tensor): shape(num_total_anchors, 2)
+            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
+            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+        Returns:
+            target_labels (Tensor): shape(bs, num_total_anchors)
+            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            fg_mask (Tensor): shape(bs, num_total_anchors)
+        """
+        self.bs = pd_scores.size(0)
+        self.n_max_boxes = gt_bboxes.size(1)
+
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
+
+        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+            mask_pos, overlaps, self.n_max_boxes)
+
+        # assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(
+            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
+        # get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
+        # get in_gts mask, (b, max_num_obj, h*w)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get topk_metric mask, (b, max_num_obj, h*w)
+        mask_topk = self.select_topk_candidates(align_metric * mask_in_gts)
+        # merge all mask to a final mask, (b, max_num_obj, h*w)
+        mask_pos = mask_topk * mask_in_gts
+
+        return mask_pos, align_metric, overlaps
+
+
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
+        # get the scores of each grid for each gt cls
+        bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
+
+        overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False).squeeze(3).clamp(0)
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+
+        return align_metric, overlaps
+
+
+    def select_topk_candidates(self, metrics, largest=True):
+        """
+        Args:
+            metrics: (b, max_num_obj, h*w).
+            topk_mask: (b, max_num_obj, topk) or None
+        """
+
+        num_anchors = metrics.shape[-1]  # h*w
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
+        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).tile([1, 1, self.topk])
+        # (b, max_num_obj, topk)
+        topk_idxs[~topk_mask] = 0
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
+        # filter invalid bboxes
+        is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
+        return is_in_topk.to(metrics.dtype)
+
+
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        """
+        Args:
+            gt_labels: (b, max_num_obj, 1)
+            gt_bboxes: (b, max_num_obj, 4)
+            target_gt_idx: (b, h*w)
+            fg_mask: (b, h*w)
+        """
+
+        # assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+
+        # assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
+        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+
+        # assigned target scores
+        target_labels.clamp_(0)
+        target_scores = F.one_hot(target_labels, self.num_classes)  # (b, h*w, 80)
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+        return target_labels, target_bboxes, target_scores
+    
+
+# -------------------------- Basic Functions --------------------------
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+    """select the positive anchors's center in gt
+    Args:
+        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
+        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    n_anchors = xy_centers.size(0)
+    bs, n_max_boxes, _ = gt_bboxes.size()
+    _gt_bboxes = gt_bboxes.reshape([-1, 4])
+    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+    b_lt = xy_centers - gt_bboxes_lt
+    b_rb = gt_bboxes_rb - xy_centers
+    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+    """if an anchor box is assigned to multiple gts,
+        the one with the highest iou will be selected.
+    Args:
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    Return:
+        target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        fg_mask (Tensor): shape(bs, num_total_anchors)
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    fg_mask = mask_pos.sum(axis=-2)
+    if fg_mask.max() > 1:
+        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1])
+        max_overlaps_idx = overlaps.argmax(axis=1)
+        is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes)
+        is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)
+        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos)
+        fg_mask = mask_pos.sum(axis=-2)
+    target_gt_idx = mask_pos.argmax(axis=-2)
+    return target_gt_idx, fg_mask, mask_pos
+
+
+def iou_calculator(box1, box2, eps=1e-9):
+    """Calculate iou for batch
+    Args:
+        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
+        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
+    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
+    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
+    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
+    x1y1 = torch.maximum(px1y1, gx1y1)
+    x2y2 = torch.minimum(px2y2, gx2y2)
+    overlap = (x2y2 - x1y1).clip(0).prod(-1)
+    area1 = (px2y2 - px1y1).clip(0).prod(-1)
+    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
+    union = area1 + area2 - overlap + eps
+
+    return overlap / union
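
A minimal usage sketch of the three helpers above, with random tensors as stand-ins (not part of the commit; in the real assigner the positive mask also includes a top-k metric filter, which mask_in_gts merely stands in for here, and fg_mask then gates the one-hot scores built in get_targets):

import torch

bs, n_max_boxes, n_anchors = 2, 3, 100
xy_centers = torch.rand(n_anchors, 2) * 640              # anchor centers, (A, 2)
gt_bboxes = torch.rand(bs, n_max_boxes, 4) * 320
gt_bboxes[..., 2:] += gt_bboxes[..., :2]                 # ensure x2y2 >= x1y1
pd_bboxes = torch.rand(bs, n_anchors, 4) * 320
pd_bboxes[..., 2:] += pd_bboxes[..., :2]

# (bs, n_max_boxes, A): anchors whose centers fall inside each gt box
mask_in_gts = select_candidates_in_gts(xy_centers, gt_bboxes)
# (bs, n_max_boxes, A): pairwise IoU between gts and predictions
overlaps = iou_calculator(gt_bboxes, pd_bboxes) * mask_in_gts
# resolve anchors claimed by several gts in favor of the highest IoU
target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
    mask_in_gts, overlaps, n_max_boxes)
print(target_gt_idx.shape, fg_mask.shape)                # (bs, A), (bs, A)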

+ 14 - 13
models/detectors/yolox2/build.py

@@ -33,23 +33,24 @@ def build_yolox2(args, cfg, device, num_classes=80, trainable=False, deploy=Fals
     # Init head
     init_prob = 0.01
     bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-    for det_head in model.det_heads:
-        # obj pred
-        b = det_head.obj_pred.bias.view(1, -1)
+    ## obj pred
+    for obj_pred in model.obj_preds:
+        b = obj_pred.bias.view(1, -1)
         b.data.fill_(bias_value.item())
-        det_head.obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # cls pred
-        b = det_head.cls_pred.bias.view(1, -1)
+        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    ## cls pred
+    for cls_pred in model.cls_preds:
+        b = cls_pred.bias.view(1, -1)
         b.data.fill_(bias_value.item())
-        det_head.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # reg pred
-        b = det_head.reg_pred.bias.view(-1, )
+        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    ## reg pred
+    for reg_pred in model.reg_preds:
+        b = reg_pred.bias.view(-1, )
         b.data.fill_(1.0)
-        det_head.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        w = det_head.reg_pred.weight
+        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = reg_pred.weight
         w.data.fill_(0.)
-        det_head.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
+        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
 
     # -------------- Build criterion --------------
     criterion = None
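
For reference, the bias init kept above is the standard focal-loss prior trick: with init_prob p = 0.01, the obj/cls logits start at -log((1 - p)/p) ≈ -4.6, so their sigmoid output is ≈ 0.01 and early training is not swamped by easy negatives. A quick standalone check:

import torch

init_prob = 0.01
bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
print(bias_value.item())                 # ~ -4.595
print(torch.sigmoid(bias_value).item())  # ~ 0.01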

+ 36 - 15
models/detectors/yolox2/yolox2.py

@@ -34,6 +34,7 @@ class YOLOX2(nn.Module):
         self.nms_thresh = nms_thresh
         self.topk = topk
         self.deploy = deploy
+        self.head_dim = round(256*cfg['width'])
         
         # ---------------------- Network Parameters ----------------------
         ## ----------- Backbone -----------
@@ -45,13 +46,24 @@ class YOLOX2(nn.Module):
         
         ## ----------- Neck: FPN -----------
         self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=round(256*cfg['width']))
-        self.head_dim = self.fpn.out_dim
+        self.fpn_dims = self.fpn.out_dim
 
         ## ----------- Heads -----------
-        self.det_heads = nn.ModuleList(
-            [build_head(cfg, head_dim, head_dim, num_classes) 
-            for head_dim in self.head_dim
-            ])
+        self.group_heads = build_head(cfg, self.fpn_dims, self.head_dim, num_classes) 
+
+        ## ----------- Preds -----------
+        self.obj_preds = nn.ModuleList(
+            [nn.Conv2d(self.head_dim, 1, kernel_size=1)
+             for _ in range(len(self.stride))
+             ])
+        self.cls_preds = nn.ModuleList(
+            [nn.Conv2d(self.head_dim, num_classes, kernel_size=1)
+             for _ in range(len(self.stride))
+             ])
+        self.reg_preds = nn.ModuleList(
+            [nn.Conv2d(self.head_dim, 4, kernel_size=1)
+             for _ in range(len(self.stride))
+             ])
 
 
     # ---------------------- Basic Functions ----------------------
@@ -139,17 +151,22 @@ class YOLOX2(nn.Module):
         pyramid_feats = self.fpn(pyramid_feats)
 
         # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.group_heads(pyramid_feats)
+
+        # ---------------- Preds ----------------
         all_obj_preds = []
         all_cls_preds = []
         all_box_preds = []
-        for level, (feat, head) in enumerate(zip(pyramid_feats, self.det_heads)):
-            # ---------------- Pred ----------------
-            obj_pred, cls_pred, reg_pred = head(feat)
-
+        for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
+            # prediction
+            obj_pred = self.obj_preds[level](reg_feat)
+            cls_pred = self.cls_preds[level](cls_feat)
+            reg_pred = self.reg_preds[level](reg_feat)
+            
             # anchors: [M, 2]
-            fmp_size = cls_pred.shape[-2:]
+            fmp_size = cls_feat.shape[-2:]
             anchors = self.generate_anchors(level, fmp_size)
-
+            
             # [1, C, H, W] -> [H, W, C] -> [M, C]
             obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
             cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
@@ -199,14 +216,18 @@ class YOLOX2(nn.Module):
             pyramid_feats = self.fpn(pyramid_feats)
 
             # ---------------- Heads ----------------
+            cls_feats, reg_feats = self.group_heads(pyramid_feats)
+
+            # ---------------- Preds ----------------
             all_anchors = []
             all_obj_preds = []
             all_cls_preds = []
             all_box_preds = []
-            all_strides = []
-            for level, (feat, head) in enumerate(zip(pyramid_feats, self.det_heads)):
-                # ---------------- Pred ----------------
-                obj_pred, cls_pred, reg_pred = head(feat)
+            for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
+                # prediction
+                obj_pred = self.obj_preds[level](reg_feat)
+                cls_pred = self.cls_preds[level](cls_feat)
+                reg_pred = self.reg_preds[level](reg_feat)
 
                 B, _, H, W = cls_pred.size()
                 fmp_size = [H, W]
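
To see the refactor end to end: the fused per-level DecoupledHead is split into a shared group_heads module that returns per-level (cls_feats, reg_feats), while the 1x1 obj/cls/reg convs now belong to the detector itself. A self-contained sketch of that flow, with hypothetical dims (head_dim and the feature maps are random stand-ins for group_heads output):

import torch
import torch.nn as nn

num_classes, head_dim, strides = 80, 64, (8, 16, 32)
obj_preds = nn.ModuleList([nn.Conv2d(head_dim, 1, kernel_size=1) for _ in strides])
cls_preds = nn.ModuleList([nn.Conv2d(head_dim, num_classes, kernel_size=1) for _ in strides])
reg_preds = nn.ModuleList([nn.Conv2d(head_dim, 4, kernel_size=1) for _ in strides])

# stand-ins for the (cls_feats, reg_feats) returned by self.group_heads(...)
cls_feats = [torch.randn(1, head_dim, 640 // s, 640 // s) for s in strides]
reg_feats = [torch.randn(1, head_dim, 640 // s, 640 // s) for s in strides]

for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
    obj_pred = obj_preds[level](reg_feat)                # [1, 1, H, W]
    cls_pred = cls_preds[level](cls_feat)                # [1, 80, H, W]
    reg_pred = reg_preds[level](reg_feat)                # [1, 4, H, W]
    # flatten to [M, C], as in the forward above
    obj_pred = obj_pred[0].permute(1, 2, 0).reshape(-1, 1)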

+ 63 - 97
models/detectors/yolox2/yolox2_head.py

@@ -1,67 +1,61 @@
 import torch
 import torch.nn as nn
-try:
-    from .yolox2_basic import Conv
-except:
-    from yolox2_basic import Conv
 
+from .yolox2_basic import Conv
 
-class DecoupledHead(nn.Module):
-    def __init__(self, cfg, in_dim, out_dim, num_classes=80):
+
+class SingleLevelHead(nn.Module):
+    def __init__(self, in_dim, out_dim, num_classes, num_cls_head, num_reg_head, act_type, norm_type, depthwise):
         super().__init__()
-        print('==============================')
-        print('Head: Decoupled Head')
         # --------- Basic Parameters ----------
         self.in_dim = in_dim
         self.num_classes = num_classes
-        self.num_cls_head=cfg['num_cls_head']
-        self.num_reg_head=cfg['num_reg_head']
-
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        self.act_type = act_type
+        self.norm_type = norm_type
+        self.depthwise = depthwise
+        
         # --------- Network Parameters ----------
         ## cls head
         cls_feats = []
         self.cls_out_dim = out_dim
-        for i in range(cfg['num_cls_head']):
+        for i in range(num_cls_head):
             if i == 0:
                 cls_feats.append(
                     Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
-                        act_type=cfg['head_act'],
-                        norm_type=cfg['head_norm'],
-                        depthwise=cfg['head_depthwise'])
+                         act_type=act_type,
+                         norm_type=norm_type,
+                         depthwise=depthwise)
                         )
             else:
                 cls_feats.append(
                     Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
-                        act_type=cfg['head_act'],
-                        norm_type=cfg['head_norm'],
-                        depthwise=cfg['head_depthwise'])
+                        act_type=act_type,
+                        norm_type=norm_type,
+                        depthwise=depthwise)
                         )      
         ## reg head
         reg_feats = []
         self.reg_out_dim = out_dim
-        for i in range(cfg['num_reg_head']):
+        for i in range(num_reg_head):
             if i == 0:
                 reg_feats.append(
                     Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
-                        act_type=cfg['head_act'],
-                        norm_type=cfg['head_norm'],
-                        depthwise=cfg['head_depthwise'])
+                         act_type=act_type,
+                         norm_type=norm_type,
+                         depthwise=depthwise)
                         )
             else:
                 reg_feats.append(
                     Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
-                        act_type=cfg['head_act'],
-                        norm_type=cfg['head_norm'],
-                        depthwise=cfg['head_depthwise'])
+                         act_type=act_type,
+                         norm_type=norm_type,
+                         depthwise=depthwise)
                         )
         self.cls_feats = nn.Sequential(*cls_feats)
         self.reg_feats = nn.Sequential(*reg_feats)
 
-        ## Pred
-        self.obj_pred = nn.Conv2d(self.cls_out_dim, 1, kernel_size=1) 
-        self.cls_pred = nn.Conv2d(self.cls_out_dim, num_classes, kernel_size=1) 
-        self.reg_pred = nn.Conv2d(self.reg_out_dim, 4, kernel_size=1) 
-
 
     def forward(self, x):
         """
@@ -70,78 +64,50 @@ class DecoupledHead(nn.Module):
         cls_feats = self.cls_feats(x)
         reg_feats = self.reg_feats(x)
 
-        obj_pred = self.obj_pred(reg_feats)
-        cls_pred = self.cls_pred(cls_feats)
-        reg_pred = self.reg_pred(reg_feats)
+        return cls_feats, reg_feats
+    
 
-        return obj_pred, cls_pred, reg_pred
+class MultiLevelHead(nn.Module):
+    def __init__(self, cfg, in_dims, out_dim, num_classes=80):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.num_classes = num_classes
+
+        ## ----------- Network Parameters -----------
+        self.det_heads = nn.ModuleList(
+            [SingleLevelHead(
+                in_dim,
+                out_dim,
+                num_classes,
+                cfg['num_cls_head'],
+                cfg['num_reg_head'],
+                cfg['head_act'],
+                cfg['head_norm'],
+                cfg['head_depthwise'])
+                for in_dim in in_dims
+            ])
+
+
+    def forward(self, feats):
+        """
+            feats: List[Tensor], each of shape [B, C, H, W]
+        """
+        cls_feats = []
+        reg_feats = []
+        for feat, head in zip(feats, self.det_heads):
+            # ---------------- Feats ----------------
+            cls_feat, reg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+
+        return cls_feats, reg_feats
     
 
 # build detection head
 def build_head(cfg, in_dim, out_dim, num_classes=80):
     if cfg['head'] == 'decoupled_head':
-        head = DecoupledHead(cfg, in_dim, out_dim, num_classes) 
+        head = MultiLevelHead(cfg, in_dim, out_dim, num_classes) 
 
     return head
-
-
-if __name__ == '__main__':
-    import time
-    from thop import profile
-    cfg = {
-        'head': 'decoupled_head',
-        'num_cls_head': 2,
-        'num_reg_head': 2,
-        'head_act': 'silu',
-        'head_norm': 'BN',
-        'head_depthwise': False,
-        'reg_max': 16,
-    }
-    fpn_dims = [256, 512, 512]
-    # Head-1
-    model = build_head(cfg, 256, fpn_dims, num_classes=80)
-    x = torch.randn(1, 256, 80, 80)
-    t0 = time.time()
-    outputs = model(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    # for out in outputs:
-    #     print(out.shape)
-
-    print('==============================')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Head-1: Params : {:.2f} M'.format(params / 1e6))
-
-    # Head-2
-    model = build_head(cfg, 512, fpn_dims, num_classes=80)
-    x = torch.randn(1, 512, 40, 40)
-    t0 = time.time()
-    outputs = model(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    # for out in outputs:
-    #     print(out.shape)
-
-    print('==============================')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('Head-2: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Head-2: Params : {:.2f} M'.format(params / 1e6))
-
-    # Head-3
-    model = build_head(cfg, 512, fpn_dims, num_classes=80)
-    x = torch.randn(1, 512, 20, 20)
-    t0 = time.time()
-    outputs = model(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    # for out in outputs:
-    #     print(out.shape)
-
-    print('==============================')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('Head-3: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Head-3: Params : {:.2f} M'.format(params / 1e6))
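
And a hedged smoke test for the new MultiLevelHead, in the spirit of the __main__ block removed above (the cfg mirrors the repo's decoupled_head settings; fpn_dims and the spatial sizes are made up, and the repo's Conv block from yolox2_basic is assumed importable):

import torch

cfg = {'head': 'decoupled_head', 'num_cls_head': 2, 'num_reg_head': 2,
       'head_act': 'silu', 'head_norm': 'BN', 'head_depthwise': False}
fpn_dims = [64, 128, 256]                  # hypothetical P3/P4/P5 channels
head = build_head(cfg, fpn_dims, out_dim=64, num_classes=80)

feats = [torch.randn(1, c, s, s) for c, s in zip(fpn_dims, (80, 40, 20))]
cls_feats, reg_feats = head(feats)
for cls_feat, reg_feat in zip(cls_feats, reg_feats):
    print(cls_feat.shape, reg_feat.shape)  # [1, 64, H, W] at each level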