yjh0410 1 year ago
Parent
Commit
60fbb4ddd7

+ 4 - 0
yolo/config/__init__.py

@@ -8,6 +8,7 @@ from .yolov6_config     import build_yolov6_config
 from .yolov8_config     import build_yolov8_config
 from .yolov8_e2e_config import build_yolov8_e2e_config
 from .gelan_config      import build_gelan_config
+from .rtcdet_config     import build_rtcdet_config
 from .rtdetr_config     import build_rtdetr_config
 
 
@@ -33,6 +34,9 @@ def build_config(args):
         cfg = build_yolov8_config(args)
     elif 'gelan' in args.model:
         cfg = build_gelan_config(args)
+    elif 'rtcdet' in args.model:
+        cfg = build_rtcdet_config(args)
+        
     # ----------- RT-DETR -----------
     elif 'rtdetr' in args.model:
         cfg = build_rtdetr_config(args)
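The dispatch keys on a substring of `args.model`, so every `rtcdet_*` name defined in the new config file routes here. A minimal sketch of how the branch is exercised (assuming the repo root is on `PYTHONPATH`; the `args` object below is a stand-in for the argparse namespace from `train.py`):

```python
from types import SimpleNamespace

from yolo.config import build_config  # assumption: repo root on PYTHONPATH

# Hypothetical stand-in for the argparse namespace produced by train.py.
args = SimpleNamespace(model='rtcdet_s')
cfg = build_config(args)              # 'rtcdet' in args.model -> build_rtcdet_config
print(type(cfg).__name__)             # RTCDet_Small_Config
```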

+ 217 - 0
yolo/config/rtcdet_config.py

@@ -0,0 +1,217 @@
+# RTCDet config
+
+
+def build_rtcdet_config(args):
+    if   args.model == 'rtcdet_n':
+        return RTCDet_Nano_Config()
+    elif args.model == 'rtcdet_t':
+        return RTCDet_Tiny_Config()
+    elif args.model == 'rtcdet_s':
+        return RTCDet_Small_Config()
+    elif args.model == 'rtcdet_m':
+        return RTCDet_Medium_Config()
+    elif args.model == 'rtcdet_l':
+        return RTCDet_Large_Config()
+    elif args.model == 'rtcdet_x':
+        return RTCDet_xLarge_Config()
+    else:
+        raise NotImplementedError("No config for model: {}".format(args.model))
+    
+# RTCDet-Base config
+class RTCDetBaseConfig(object):
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        self.stage_dims  = [64, 128, 256, 512, 512]
+        self.stage_depth = [3, 6, 6, 3]
+        self.width    = 1.0
+        self.depth    = 1.0
+        self.ratio    = 1.0
+        self.reg_max  = 16
+        self.out_stride = [8, 16, 32]
+        self.max_stride = 32
+        self.num_levels = 3
+        ## Backbone
+        self.bk_block    = 'elan_layer'
+        self.bk_ds_block = 'conv'
+        self.bk_act      = 'silu'
+        self.bk_norm     = 'bn'
+        self.bk_depthwise   = False
+        ## Neck
+        self.neck_act       = 'silu'
+        self.neck_norm      = 'bn'
+        self.neck_depthwise = False
+        self.neck_expand_ratio = 0.5
+        self.spp_pooling_size  = 5
+        ## FPN
+        self.fpn_block     = 'elan_layer'
+        self.fpn_ds_block  = 'conv'
+        self.fpn_act       = 'silu'
+        self.fpn_norm      = 'bn'
+        self.fpn_depthwise = False
+        ## Head
+        self.head_act  = 'silu'
+        self.head_norm = 'bn'
+        self.head_depthwise = False
+        self.num_cls_head   = 2
+        self.num_reg_head   = 2
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 1000
+        self.val_conf_thresh = 0.001
+        self.val_nms_thresh  = 0.7
+        self.test_topk = 100
+        self.test_conf_thresh = 0.2
+        self.test_nms_thresh  = 0.5
+
+        # ---------------- Assignment & Loss config ----------------
+        self.loss_cls_type = "bce"
+        self.matcher_dict = {"tal_alpha": 0.5, "tal_beta": 6.0, "topk_candidates": 10}
+        self.weight_dict  = {"loss_cls": 0.5, "loss_box": 7.5, "loss_dfl": 1.5}
+
+        # ---------------- Assignment & Loss config ----------------
+        # self.loss_cls_type = "vfl"
+        # self.matcher_dict = {"tal_alpha": 1.0, "tal_beta": 6.0, "topk_candidates": 13}   # For VFL
+        # self.weight_dict  = {"loss_cls": 1.0, "loss_box": 2.5, "loss_dfl": 0.5}   # For VFL
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema = True
+        self.ema_decay = 0.9998
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.no_norm_decay = True
+        self.no_bias_decay = True
+        self.batch_size_base = 64
+        self.optimizer    = 'adamw'
+        self.base_lr      = 0.001
+        self.min_lr_ratio = 0.05      # min_lr  = base_lr * min_lr_ratio
+        self.momentum     = 0.9
+        self.weight_decay = 0.05
+        self.clip_max_norm   = 35.0
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+        self.use_fp16        = True  # use mixed precision
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup_epoch = 3
+        self.lr_scheduler = "cosine"
+        self.max_epoch    = 500
+        self.eval_epoch   = 10
+        self.no_aug_epoch = 15
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'yolo'
+        self.box_format = 'xyxy'
+        self.normalize_coords = False
+        self.mosaic_prob = 0.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0           # approximated by YOLOX-style mixup
+        self.multi_scale = [0.5, 1.5]   # multi scale: [img_size * 0.5, img_size * 1.5]
+        ## Pixel mean & std
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [255., 255., 255.]
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+        self.affine_params = {
+            'degrees': 0.0,
+            'translate': 0.2,
+            'scale': [0.1, 2.0],
+            'shear': 0.0,
+            'perspective': 0.0,
+            'hsv_h': 0.015,
+            'hsv_s': 0.7,
+            'hsv_v': 0.4,
+        }
+
+    def print_config(self):
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+
+# RTCDet-N
+class RTCDet_Nano_Config(RTCDetBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.25
+        self.depth = 0.34
+        self.ratio = 2.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 1.0
+
+# RTCDet-T
+class RTCDet_Tiny_Config(RTCDetBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.375
+        self.depth = 0.34
+        self.ratio = 2.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 1.0
+
+# RTCDet-S
+class RTCDet_Small_Config(RTCDetBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.50
+        self.depth = 0.34
+        self.ratio = 2.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.05
+        self.copy_paste  = 1.0
+
+# RTCDet-M
+class RTCDet_Medium_Config(RTCDetBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.75
+        self.depth = 0.67
+        self.ratio = 1.5
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 1.0
+
+# RTCDet-L
+class RTCDet_Large_Config(RTCDetBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 1.0
+        self.depth = 1.0
+        self.ratio = 1.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.15
+        self.copy_paste  = 1.0
+
+# RTCDet-X
+class RTCDet_xLarge_Config(RTCDetBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 1.25
+        self.depth = 1.0
+        self.ratio = 1.0
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.2
+        self.copy_paste  = 1.0
+        
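For reference, the `width`/`depth`/`ratio` multipliers scale the base `stage_dims`/`stage_depth` with the same rounding the backbone applies (only the last stage is additionally multiplied by `ratio`). A small worked sketch for RTCDet-S, not repo code:

```python
# Worked example: RTCDet-S (width=0.50, depth=0.34, ratio=2.0) applied to the
# base stage_dims/stage_depth above, mirroring RTCBackbone's rounding rule.
stage_dims, stage_depth = [64, 128, 256, 512, 512], [3, 6, 6, 3]
width, depth, ratio = 0.50, 0.34, 2.0

scaled_dims = [round(d * width * ratio) if i == len(stage_dims) - 1 else round(d * width)
               for i, d in enumerate(stage_dims)]
scaled_depth = [round(n * depth) for n in stage_depth]

print(scaled_dims)   # [32, 64, 128, 256, 512]
print(scaled_depth)  # [1, 2, 2, 1]
```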

+ 4 - 0
yolo/models/__init__.py

@@ -11,6 +11,7 @@ from .yolov6.build     import build_yolov6
 from .yolov8.build     import build_yolov8
 from .yolov8_e2e.build import build_yolov8_e2e
 from .gelan.build      import build_gelan
+from .rtcdet.build     import build_rtcdet
 from .rtdetr.build     import build_rtdetr
 
 
@@ -44,6 +45,9 @@ def build_model(args, cfg, is_val=False):
     ## GElan
     elif 'gelan' in args.model:
         model, criterion = build_gelan(cfg, is_val)
+    ## RTCDet
+    elif 'rtcdet' in args.model:
+        model, criterion = build_rtcdet(cfg, is_val)
     ## RT-DETR
     elif 'rtdetr' in args.model:
         model, criterion = build_rtdetr(cfg, is_val)

+ 56 - 0
yolo/models/rtcdet/README.md

@@ -0,0 +1,56 @@
+# RTCDet: My Empirical Study of Real-Time Convolutional Object Detectors.
+
+- VOC
+
+|     Model   | Batch | Scale | AP<sup>val<br>0.5 | Weight |  Logs  |
+|-------------|-------|-------|-------------------|--------|--------|
+| RTCDet-S    | 1xb16 |  640  |               |  |  |
+
+- COCO
+
+|    Model    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |  Logs  |
+|-------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|--------|
+| RTCDet-S    | 1xb16 |  640  |                    |               |   26.9            |   8.9             |  |  |
+
+
+
+## Train RTCDet
+### Single GPU
+Taking training RTCDet-S on COCO as an example:
+```Shell
+python train.py --cuda -d coco --root path/to/coco -m rtcdet_s -bs 16 --fp16 
+```
+
+### Multi GPU
+Taking training RTCDet-S on COCO as an example:
+```Shell
+python -m torch.distributed.run --nproc_per_node=8 train.py --cuda --distributed -d coco --root path/to/coco -m rtcdet_s -bs 256 --fp16 
+```
+
+## Test RTCDet
+Taking testing RTCDet-S on COCO-val as an example:
+```Shell
+python test.py --cuda -d coco --root path/to/coco -m rtcdet_s --weight path/to/RTCDet.pth --show 
+```
+
+## Evaluate RTCDet
+Taking evaluating RTCDet-S on COCO-val as an example:
+```Shell
+python eval.py --cuda -d coco --root path/to/coco -m rtcdet_s --weight path/to/RTCDet.pth 
+```
+
+## Demo
+### Detect with Image
+```Shell
+python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m rtcdet_s --weight path/to/weight --show
+```
+
+### Detect with Video
+```Shell
+python demo.py --mode video --path_to_vid path/to/video --cuda -m rtcdet_s --weight path/to/weight --show --gif
+```
+
+### Detect with Camera
+```Shell
+python demo.py --mode camera --cuda -m rtcdet_s --weight path/to/weight --show --gif
+```

+ 18 - 0
yolo/models/rtcdet/build.py

@@ -0,0 +1,18 @@
+import torch.nn as nn
+
+from .loss import SetCriterion
+from .rtcdet import RTCDet
+
+
+# build object detector
+def build_rtcdet(cfg, is_val=False):
+    # -------------- Build RTCDet --------------
+    model = RTCDet(cfg, is_val)
+
+    # -------------- Build criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+        
+    return model, criterion
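A minimal usage sketch (assuming `num_classes` has been attached to the config beforehand: the head and `SetCriterion` both read `cfg.num_classes`, which the configs above do not define themselves):

```python
from yolo.config.rtcdet_config import RTCDet_Small_Config
from yolo.models.rtcdet.build import build_rtcdet  # assumption: repo root on PYTHONPATH

cfg = RTCDet_Small_Config()
cfg.num_classes = 80                              # assumption: normally injected by the train script
model, criterion = build_rtcdet(cfg, is_val=True)  # criterion is None when is_val=False
print(type(model).__name__, type(criterion).__name__)
```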

+ 197 - 0
yolo/models/rtcdet/loss.py

@@ -0,0 +1,197 @@
+import torch
+import torch.nn.functional as F
+
+from utils.box_ops import bbox2dist, bbox_iou
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import TaskAlignedAssigner
+
+
+# ---------- Criterion for RTCDet ----------
+class SetCriterion(object):
+    def __init__(self, cfg):
+        # --------------- Basic parameters ---------------
+        self.cfg = cfg
+        self.reg_max = cfg.reg_max
+        self.num_classes   = cfg.num_classes
+        self.loss_cls_type = cfg.loss_cls_type
+        self.matcher_dict  = cfg.matcher_dict
+        # --------------- Loss config ---------------
+        self.loss_cls_weight = cfg.weight_dict["loss_cls"]
+        self.loss_box_weight = cfg.weight_dict["loss_box"]
+        self.loss_dfl_weight = cfg.weight_dict["loss_dfl"]
+        # --------------- Matcher config ---------------
+        self.matcher = TaskAlignedAssigner(num_classes     = cfg.num_classes,
+                                           topk_candidates = self.matcher_dict["topk_candidates"],
+                                           alpha           = self.matcher_dict["tal_alpha"],
+                                           beta            = self.matcher_dict["tal_beta"],
+                                           )
+
+    def loss_classes(self, pred_cls, gt_score):
+        # Compute VFL loss
+        if self.loss_cls_type == "vfl":
+            alpha, gamma = 0.75, 2.0
+            pred_sigmoid = pred_cls.sigmoid()
+            focal_weight = gt_score * (gt_score > 0.0).float() + \
+                alpha * (pred_sigmoid - gt_score).abs().pow(gamma) * \
+                (gt_score <= 0.0).float()
+            
+            loss_cls = F.binary_cross_entropy_with_logits(
+                pred_cls, gt_score, reduction='none') * focal_weight
+        # Compute BCE loss
+        else:
+            loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
+
+        return loss_cls
+    
+    def loss_bboxes(self, pred_box, gt_box, bbox_weight):
+        # regression loss
+        ious = bbox_iou(pred_box, gt_box, xywh=False, CIoU=True)
+        loss_box = (1.0 - ious.squeeze(-1)) * bbox_weight
+
+        return loss_box
+    
+    def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
+        # rescale coords by stride
+        gt_box_s = gt_box / stride
+        anchor_s = anchor / stride
+
+        # compute deltas
+        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.reg_max - 1)
+
+        gt_left = gt_ltrb_s.to(torch.long)
+        gt_right = gt_left + 1
+
+        weight_left = gt_right.to(torch.float) - gt_ltrb_s
+        weight_right = 1 - weight_left
+
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_reg.view(-1, self.reg_max),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_reg.view(-1, self.reg_max),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss_dfl = (loss_left + loss_right).mean(-1)
+        
+        if bbox_weight is not None:
+            loss_dfl *= bbox_weight
+
+        return loss_dfl
+
+    def __call__(self, outputs, targets):        
+        """
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_reg']: List(Tensor) [B, M, 4*reg_max]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['anchors']: List(Tensor) [M, 2]
+            outputs['strides']: List(Int) [8, 16, 32] output stride
+            outputs['stride_tensor']: List(Tensor) [M, 1]
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        # preds: [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+        delta_preds = torch.cat(outputs['pred_delta'], dim=1)
+        bs, num_anchors = cls_preds.shape[:2]
+        device = cls_preds.device
+        anchors = torch.cat(outputs['anchors'], dim=0)
+        strides = torch.cat(outputs['stride_tensor'], dim=0)
+
+        # --------------- label assignment ---------------
+        gt_score_targets = []
+        gt_bbox_targets = []
+        fg_masks = []
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
+
+            if self.cfg.normalize_coords:
+                img_h, img_w = outputs['image_size']
+                tgt_boxs[..., [0, 2]] *= img_w
+                tgt_boxs[..., [1, 3]] *= img_h
+            
+            if self.cfg.box_format == 'xywh':
+                tgt_boxs_x1y1 = tgt_boxs[..., :2] - 0.5 * tgt_boxs[..., 2:]
+                tgt_boxs_x2y2 = tgt_boxs[..., :2] + 0.5 * tgt_boxs[..., 2:]
+                tgt_boxs = torch.cat([tgt_boxs_x1y1, tgt_boxs_x2y2], dim=-1)
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
+                # There is no valid gt
+                fg_mask  = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
+                gt_box   = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
+            else:
+                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
+                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
+                (
+                    _,
+                    gt_box,     # [1, M, 4]
+                    gt_score,   # [1, M, C]
+                    fg_mask,    # [1, M,]
+                    _
+                ) = self.matcher(
+                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    anc_points = anchors,
+                    gt_labels = tgt_labels,
+                    gt_bboxes = tgt_boxs
+                    )
+            gt_score_targets.append(gt_score)
+            gt_bbox_targets.append(gt_box)
+            fg_masks.append(fg_mask)
+
+        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
+        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
+        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
+        num_fgs = gt_score_targets.sum()
+        
+        # Average loss normalizer across all the GPUs
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # ------------------ Classification loss ------------------
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        loss_cls = self.loss_classes(cls_preds, gt_score_targets)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # ------------------ Regression loss ------------------
+        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
+        box_targets_pos = gt_bbox_targets.view(-1, 4)[fg_masks]
+        bbox_weight = gt_score_targets[fg_masks].sum(-1)
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
+        loss_box = loss_box.sum() / num_fgs
+
+        # ------------------ Distribution focal loss ------------------
+        reg_preds_pos = reg_preds.view(-1, 4*self.reg_max)[fg_masks]
+        anchors_pos = anchors[None].repeat(bs, 1, 1).view(-1, 2)[fg_masks]
+        stride_pos  = strides[None].repeat(bs, 1, 1).view(-1, 1)[fg_masks]
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, stride_pos, bbox_weight)
+        loss_dfl = loss_dfl.sum() / num_fgs
+
+        # Compute total loss
+        losses = loss_cls * self.loss_cls_weight + \
+                 loss_box * self.loss_box_weight + \
+                 loss_dfl * self.loss_dfl_weight 
+        loss_dict = dict(
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                loss_dfl = loss_dfl,
+                losses = losses
+        )
+
+        return loss_dict
+    
+
+if __name__ == "__main__":
+    pass
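`loss_dfl` above implements the usual integral-distribution regression: each target distance `t` (in stride units, clamped to `[0, reg_max - 1]` by `bbox2dist`) is split between its two neighbouring integer bins with weights `ceil(t) - t` and `t - floor(t)`. A self-contained sketch of just that weighting (not repo code; tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def dfl_sketch(pred_dist, target, reg_max=16):
    # pred_dist: [N, 4, reg_max] logits over distance bins; target: [N, 4] distances in [0, reg_max - 1]
    tl = target.long()            # left (floor) bin index
    tr = tl + 1                   # right bin index
    wl = tr.float() - target      # weight assigned to the left bin
    wr = 1.0 - wl                 # weight assigned to the right bin
    loss_l = F.cross_entropy(pred_dist.view(-1, reg_max), tl.view(-1), reduction='none').view(tl.shape) * wl
    loss_r = F.cross_entropy(pred_dist.view(-1, reg_max), tr.view(-1), reduction='none').view(tl.shape) * wr
    return (loss_l + loss_r).mean(-1)   # [N,]

pred = torch.randn(2, 4, 16)
tgt  = torch.tensor([[1.2, 3.7, 0.4, 5.0], [2.5, 2.5, 2.5, 2.5]])
print(dfl_sketch(pred, tgt))            # two per-sample DFL values
```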

+ 202 - 0
yolo/models/rtcdet/matcher.py

@@ -0,0 +1,202 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.box_ops import bbox_iou
+
+
+# -------------------------- Task Aligned Assigner --------------------------
+class TaskAlignedAssigner(nn.Module):
+    """
+        This code is adapted from https://github.com/ultralytics/ultralytics
+    """
+    def __init__(self,
+                 num_classes     = 80,
+                 topk_candidates = 10,
+                 alpha           = 0.5,
+                 beta            = 6.0, 
+                 eps             = 1e-9):
+        super(TaskAlignedAssigner, self).__init__()
+        self.topk_candidates = topk_candidates
+        self.num_classes = num_classes
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
+
+    @torch.no_grad()
+    def forward(self,
+                pd_scores,
+                pd_bboxes,
+                anc_points,
+                gt_labels,
+                gt_bboxes):
+        self.bs = pd_scores.size(0)
+        self.n_max_boxes = gt_bboxes.size(1)
+
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
+
+        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+            mask_pos, overlaps, self.n_max_boxes)
+
+        # Assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(
+            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
+        # get in_gts mask, (b, max_num_obj, h*w)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts)
+        # get topk_metric mask, (b, max_num_obj, h*w)
+        mask_topk = self.select_topk_candidates(align_metric)
+        # merge all mask to a final mask, (b, max_num_obj, h*w)
+        mask_pos = mask_topk * mask_in_gts
+
+        return mask_pos, align_metric, overlaps
+
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts):
+        """Compute alignment metric given predicted and ground truth bounding boxes."""
+        na = pd_bboxes.shape[-2]
+        mask_in_gts = mask_in_gts.bool()  # b, max_num_obj, h*w
+        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
+        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
+
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
+        # Get the scores of each grid for each gt cls
+        bbox_scores[mask_in_gts] = pd_scores[ind[0], :, ind[1]][mask_in_gts]  # b, max_num_obj, h*w
+
+        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
+        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_in_gts]
+        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_in_gts]
+        overlaps[mask_in_gts] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+        return align_metric, overlaps
+
+    def select_topk_candidates(self, metrics, largest=True):
+        """
+        Args:
+            metrics: (b, max_num_obj, h*w).
+            topk_mask: (b, max_num_obj, topk) or None
+        """
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk_candidates, dim=-1, largest=largest)
+        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
+        # (b, max_num_obj, topk)
+        topk_idxs.masked_fill_(~topk_mask, 0)
+
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
+        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
+        for k in range(self.topk_candidates):
+            # Expand topk_idxs for each value of k and add 1 at the specified positions
+            count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
+        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
+        # Filter invalid bboxes
+        count_tensor.masked_fill_(count_tensor > 1, 0)
+
+        return count_tensor.to(metrics.dtype)
+
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        # Assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+
+        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
+        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+
+        # Assigned target scores
+        target_labels.clamp_(0)
+
+        # 10x faster than F.one_hot()
+        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
+                                    dtype=torch.int64,
+                                    device=target_labels.device)  # (b, h*w, 80)
+        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
+
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+        return target_labels, target_bboxes, target_scores
+    
+
+# -------------------------- Basic Functions --------------------------
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+    """select the positive anchors's center in gt
+    Args:
+        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
+        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    n_anchors = xy_centers.size(0)
+    bs, n_max_boxes, _ = gt_bboxes.size()
+    _gt_bboxes = gt_bboxes.reshape([-1, 4])
+    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+    b_lt = xy_centers - gt_bboxes_lt
+    b_rb = gt_bboxes_rb - xy_centers
+    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+    """if an anchor box is assigned to multiple gts,
+        the one with the highest iou will be selected.
+    Args:
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    Return:
+        target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        fg_mask (Tensor): shape(bs, num_total_anchors)
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    fg_mask = mask_pos.sum(-2)
+    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
+        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
+        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
+
+        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
+        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
+
+        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
+        fg_mask = mask_pos.sum(-2)
+    # Find which gt (index) each grid cell serves
+    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
+
+    return target_gt_idx, fg_mask, mask_pos
+
+def iou_calculator(box1, box2, eps=1e-9):
+    """Calculate iou for batch
+    Args:
+        box1 (Tensor): shape(bs, n_max_boxes, 4)
+        box2 (Tensor): shape(bs, num_total_anchors, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
+    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
+    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
+    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
+    x1y1 = torch.maximum(px1y1, gx1y1)
+    x2y2 = torch.minimum(px2y2, gx2y2)
+    overlap = (x2y2 - x1y1).clip(0).prod(-1)
+    area1 = (px2y2 - px1y1).clip(0).prod(-1)
+    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
+    union = area1 + area2 - overlap + eps
+
+    return overlap / union
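For context, the assigner ranks anchors by the task-aligned metric `t = s**alpha * iou**beta` (see `get_box_metrics`), where `s` is the predicted score of the gt class at an anchor and `iou` is the CIoU between the predicted box and that gt. A tiny numeric sketch with illustrative values only:

```python
import torch

alpha, beta = 0.5, 6.0                    # defaults from matcher_dict in rtcdet_config.py
scores = torch.tensor([0.9, 0.6, 0.3])    # predicted probability of the gt class at 3 anchors
ious   = torch.tensor([0.8, 0.9, 0.95])   # IoU of each predicted box with the gt
align_metric = scores.pow(alpha) * ious.pow(beta)
print(align_metric)                        # the top-k anchors by this metric become positives
```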

+ 148 - 0
yolo/models/rtcdet/rtcdet.py

@@ -0,0 +1,148 @@
+# --------------- Torch components ---------------
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from .rtcdet_backbone import RTCBackbone
+from .rtcdet_neck     import SPPF
+from .rtcdet_pafpn    import RTCPaFPN
+from .rtcdet_head     import MSDetHead
+from .rtcdet_pred     import MSDetPredLayer
+
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
+
+
+# Real-time Convolutional Detector
+class RTCDet(nn.Module):
+    def __init__(self,
+                 cfg,
+                 is_val = False,
+                 ) -> None:
+        super(RTCDet, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh     = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh      = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        self.no_multi_labels = False if is_val else True
+        
+        # ---------------------- Network Parameters ----------------------
+        ## Backbone
+        self.backbone = RTCBackbone(cfg)
+        self.neck     = SPPF(cfg, self.backbone.pyramid_feat_dims[-1], self.backbone.pyramid_feat_dims[-1])
+        self.fpn      = RTCPaFPN(cfg, self.backbone.pyramid_feat_dims)
+        self.head     = MSDetHead(cfg, self.fpn.out_dims)
+        self.pred     = MSDetPredLayer(cfg, self.head.cls_head_dim, self.head.reg_head_dim)
+
+    def post_process(self, cls_preds, box_preds):
+        """
+        We process predictions at each scale hierarchically
+        Input:
+            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
+            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
+        Output:
+            bboxes: np.array -> [N, 4]
+            scores: np.array -> [N,]
+            labels: np.array -> [N,]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            if self.no_multi_labels:
+                # [M,]
+                scores, labels = torch.max(cls_pred_i.sigmoid(), dim=1)
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # topk candidates
+                predicted_prob, topk_idxs = scores.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                labels = labels[topk_idxs]
+                bboxes = box_pred_i[topk_idxs]
+            else:
+                # [M, C] -> [MC,]
+                scores_i = cls_pred_i.sigmoid().flatten()
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # torch.sort is actually faster than .topk (at least on GPUs)
+                predicted_prob, topk_idxs = scores_i.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+                labels = topk_idxs % self.num_classes
+
+                bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        bboxes = torch.cat(all_bboxes, dim=0)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes)
+        
+        return bboxes, scores, labels
+    
+    def forward(self, x):
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(x)
+        
+        # ---------------- Neck: SPP ----------------
+        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+        # ---------------- Neck: PaFPN ----------------
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.head(pyramid_feats)
+
+        # ---------------- Preds ----------------
+        outputs = self.pred(cls_feats, reg_feats)
+        outputs['image_size'] = [x.shape[2], x.shape[3]]
+
+        if not self.training:
+            all_cls_preds = outputs['pred_cls']
+            all_box_preds = outputs['pred_box']
+
+            # post process
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
+        
+        return outputs
+    
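In the multi-label branch of `post_process`, flattening the `[M, C]` score map lets a single top-k pass recover both the anchor index and the class label via integer division and modulo. A standalone sketch of just that decoding (shapes are illustrative, not repo code):

```python
import torch

num_classes, conf_thresh, topk = 80, 0.2, 100
cls_pred = torch.randn(8400, num_classes)              # one image: M = 8400 anchors
scores = cls_pred.sigmoid().flatten()                   # [M * C]

num_topk = min(topk, cls_pred.shape[0])                 # mirrors min(topk_candidates, M)
predicted_prob, topk_idxs = scores.sort(descending=True)
topk_scores, topk_idxs = predicted_prob[:num_topk], topk_idxs[:num_topk]

keep = topk_scores > conf_thresh                        # drop low-confidence candidates
topk_idxs = topk_idxs[keep]
anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor')  # which anchor
labels = topk_idxs % num_classes                                         # which class
print(anchor_idxs.shape, labels.shape)
```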

+ 135 - 0
yolo/models/rtcdet/rtcdet_backbone.py

@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .rtcdet_basic import BasicConv, ElanLayer, MDown, ADown
+except ImportError:
+    from  rtcdet_basic import BasicConv, ElanLayer, MDown, ADown
+
+
+# ------------------ Basic functions ------------------
+class RTCBackbone(nn.Module):
+    def __init__(self, cfg):
+        super(RTCBackbone, self).__init__()
+        # ------------------ Basic setting ------------------
+        self.stage_depth = [round(nb  * cfg.depth) for nb  in cfg.stage_depth]
+        self.stage_dims  = [round(dim * cfg.width * cfg.ratio) if i == len(cfg.stage_dims) - 1
+                            else round(dim * cfg.width) for i, dim in enumerate(cfg.stage_dims)]
+        self.pyramid_feat_dims = self.stage_dims[-3:]
+        
+        # ------------------ Model setting ------------------
+        ## P1/2
+        self.layer_1 = BasicConv(3, self.stage_dims[0], kernel_size=6, padding=2, stride=2,
+                                 act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
+        # P2/4
+        self.layer_2 = nn.Sequential(
+            self.make_downsample_block(cfg, self.stage_dims[0], self.stage_dims[1]),
+            self.make_stage_block(cfg, self.stage_dims[1], self.stage_dims[1], self.stage_depth[0])
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            self.make_downsample_block(cfg, self.stage_dims[1], self.stage_dims[2]),
+            self.make_stage_block(cfg, self.stage_dims[2], self.stage_dims[2], self.stage_depth[1])
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            self.make_downsample_block(cfg, self.stage_dims[2], self.stage_dims[3]),
+            self.make_stage_block(cfg, self.stage_dims[3], self.stage_dims[3], self.stage_depth[2])
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            self.make_downsample_block(cfg, self.stage_dims[3], self.stage_dims[4]),
+            self.make_stage_block(cfg, self.stage_dims[4], self.stage_dims[4], self.stage_depth[3])
+        )
+
+        # Initialize all layers
+        self.init_weights()
+                        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
+
+    def make_downsample_block(self, cfg, in_dim, out_dim):
+        if cfg.bk_ds_block == "conv":
+            return BasicConv(in_dim, out_dim, kernel_size=3, padding=1, stride=2,
+                             act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
+        if cfg.bk_ds_block == "mdown":
+            return MDown(in_dim, out_dim, cfg.bk_act, cfg.bk_norm, cfg.bk_depthwise)
+        if cfg.bk_ds_block == "adown":
+            return ADown(in_dim, out_dim, cfg.bk_act, cfg.bk_norm, cfg.bk_depthwise)
+        if cfg.bk_ds_block == "maxpool":
+            assert in_dim == out_dim
+            return nn.MaxPool2d((2, 2), stride=2)
+        else:
+            raise NotImplementedError("Unknown fpn downsample block: {}".format(cfg.fpn_ds_block))
+        
+    def make_stage_block(self, cfg, in_dim, out_dim, stage_depth):
+        if cfg.bk_block == "elan_layer":
+            return ElanLayer(in_dim     = in_dim,
+                             out_dim    = out_dim,
+                             num_blocks = stage_depth,
+                             expansion  = 0.5,
+                             shortcut   = True,
+                             act_type   = cfg.bk_act,
+                             norm_type  = cfg.bk_norm,
+                             depthwise  = cfg.bk_depthwise)
+        else:
+            raise NotImplementedError("Unknown stage block: {}".format(cfg.bk_block))
+        
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ------------------ Functions ------------------
+## Build RTCDet's backbone
+def build_backbone(cfg): 
+    # model
+    backbone = RTCBackbone(cfg)
+        
+    return backbone
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    class BaseConfig(object):
+        def __init__(self) -> None:
+            self.stage_dims =  [64, 128, 256, 512, 512]
+            self.stage_depth = [3, 6, 6, 3]
+            self.bk_block = "elan_layer"
+            self.bk_ds_block = "mdown"
+            self.bk_act = 'silu'
+            self.bk_norm = 'bn'
+            self.bk_depthwise = False
+            self.use_pretrained = False
+            self.width = 0.5
+            self.depth = 0.34
+            self.ratio = 2.0
+
+    cfg = BaseConfig()
+    model = build_backbone(cfg).cuda()
+    x = torch.randn(1, 3, 640, 640).cuda()
+
+    for _ in range(5):
+        t0 = time.time()
+        outputs = model(x)
+        t1 = time.time()
+        print('Time: ', t1 - t0)
+        
+    for out in outputs:
+        print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 313 - 0
yolo/models/rtcdet/rtcdet_basic.py

@@ -0,0 +1,313 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d=1, g=1, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'bn':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'gn':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 groups=1,                 # group
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'bn',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        use_bias = False if norm_type is not None else True
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=groups, bias=use_bias)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1, bias=use_bias)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.act(self.norm1(self.conv1(x)))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+class DWConv(nn.Module):
+    def __init__(self, 
+                 in_dim      :int,           # in channels
+                 out_dim     :int,           # out channels 
+                 kernel_size :int = 1,       # kernel size 
+                 padding     :int = 0,       # padding
+                 stride      :int = 1,       # stride
+                 dilation    :int = 1,       # dilation
+                 act_type    :str = 'lrelu', # activation
+                 norm_type   :str = 'bn',    # normalization
+                ):
+        super(DWConv, self).__init__()
+        assert in_dim == out_dim
+        use_bias = False if norm_type is not None else True
+        self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=out_dim, bias=use_bias)
+        self.norm = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        return self.act(self.norm(self.conv(x)))
+
+
+# --------------------- Downsample modules ---------------------
+class ADown(nn.Module):
+    def __init__(self,
+                 in_dim    :int,
+                 out_dim   :int,
+                 act_type  :str  = "silu",
+                 norm_type :str  = "bn",
+                 depthwise :bool = False):
+        super().__init__()
+        inter_dim = out_dim // 2
+        self.conv_layer_1 = BasicConv(in_dim // 2, inter_dim, kernel_size=3, padding=1, stride=2,
+                                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.conv_layer_2 = BasicConv(in_dim // 2, inter_dim, kernel_size=1,
+                                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+    def forward(self, x):
+        # Split
+        x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
+        x1,x2 = x.chunk(2, 1)
+
+        # Downsample branch - 1
+        x1 = self.conv_layer_1(x1)
+
+        # Downsample branch - 2
+        x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
+        x2 = self.conv_layer_2(x2)
+
+        return torch.cat([x1, x2], dim=1)
+
+class MDown(nn.Module):
+    def __init__(self,
+                 in_dim    :int,
+                 out_dim   :int,
+                 act_type  :str   = 'silu',
+                 norm_type :str   = 'bn',
+                 depthwise :bool  = False,
+                 ) -> None:
+        super().__init__()
+        inter_dim = out_dim // 2
+        self.downsample_1 = nn.Sequential(
+            nn.MaxPool2d((2, 2), stride=2),
+            BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        )
+        self.downsample_2 = nn.Sequential(
+            BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type),
+            BasicConv(inter_dim, inter_dim,
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+    def forward(self, x):
+        x1 = self.downsample_1(x)
+        x2 = self.downsample_2(x)
+
+        return torch.cat([x1, x2], dim=1)
+
+
+# --------------------- Feature processing modules ---------------------
+class MBottleneck(nn.Module):
+    def __init__(self,
+                 in_dim    :int,
+                 out_dim   :int,
+                 expansion :float = 0.5,
+                 shortcut  :bool  = False,
+                 act_type  :str   = 'silu',
+                 norm_type :str   = 'bn',
+                 depthwise :bool  = False,
+                 ) -> None:
+        super(MBottleneck, self).__init__()
+        inter_dim = int(out_dim * expansion)
+        # ----------------- Network setting -----------------
+        self.conv_layer = nn.Sequential(
+            # 3x3 conv + bn + silu
+            BasicConv(in_dim, inter_dim, kernel_size=3, padding=1, stride=1,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            # 5x5 dw conv
+            DWConv(inter_dim, inter_dim, kernel_size=5, padding=2, stride=1,
+                   act_type=None, norm_type=norm_type),
+            # 3x3 conv + bn + silu
+            BasicConv(inter_dim, out_dim, kernel_size=3, padding=1, stride=1,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+        )
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.conv_layer(x)
+
+        return x + h if self.shortcut else h
+
+class CSPLayer(nn.Module):
+    # CSP Bottleneck
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 num_blocks  :int   = 1,
+                 expansion   :float = 0.5,
+                 shortcut    :bool  = True,
+                 act_type    :str   = 'silu',
+                 norm_type   :str   = 'bn',
+                 depthwise   :bool  = False,
+                 ) -> None:
+        super().__init__()
+        inter_dim = round(out_dim * expansion)
+        self.input_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=None, norm_type=norm_type, depthwise=depthwise)
+        self.module = nn.Sequential(*[MBottleneck(inter_dim,
+                                                  inter_dim,
+                                                  expansion   = 1.0,
+                                                  shortcut    = shortcut,
+                                                  act_type    = act_type,
+                                                  norm_type   = norm_type,
+                                                  depthwise   = depthwise,
+                                                  ) for _ in range(num_blocks)])
+
+    def forward(self, x):
+        # Split
+        x1, x2 = torch.chunk(self.input_proj(x), chunks=2, dim=1)
+
+        # Branch
+        x2 = self.module(x2)
+
+        # Concatenate the two branches
+        out = torch.cat([x1, x2], dim=1)
+
+        return out
+
+class ElanLayer(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expansion  :float = 0.5,
+                 num_blocks :int   = 1,
+                 shortcut   :bool  = False,
+                 act_type   :str   = 'silu',
+                 norm_type  :str   = 'bn',
+                 depthwise  :bool  = False,
+                 ) -> None:
+        super(ElanLayer, self).__init__()
+        inter_dim = round(out_dim * expansion)
+        self.input_proj  = BasicConv(in_dim, inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.output_proj = BasicConv((2 + num_blocks) * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.module      = nn.ModuleList([MBottleneck(inter_dim,
+                                                      inter_dim,
+                                                      expansion   = 1.0,
+                                                      shortcut    = shortcut,
+                                                      act_type    = act_type,
+                                                      norm_type   = norm_type,
+                                                      depthwise   = depthwise)
+                                                      for _ in range(num_blocks)])
+
+    def forward(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
+        out = list([x1, x2])
+
+        # Bottleneck
+        out.extend(m(out[-1]) for m in self.module)
+
+        # Output proj
+        out = self.output_proj(torch.cat(out, dim=1))
+
+        return out
+    
+class GElanLayer(nn.Module):
+    """Modified YOLOv9's GELAN module"""
+    def __init__(self,
+                 in_dim     :int,
+                 inter_dims :List,
+                 out_dim    :int,
+                 num_blocks :int   = 1,
+                 shortcut   :bool  = False,
+                 act_type   :str   = 'silu',
+                 norm_type  :str   = 'bn',
+                 depthwise  :bool  = False,
+                 ) -> None:
+        super(GElanLayer, self).__init__()
+        # ----------- Basic parameters -----------
+        self.in_dim = in_dim
+        self.inter_dims = inter_dims
+        self.out_dim = out_dim
+
+        # ----------- Network parameters -----------
+        self.conv_layer_1  = BasicConv(in_dim, inter_dims[0], kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.elan_module_1 = nn.Sequential(
+             CSPLayer(inter_dims[0]//2,
+                      inter_dims[1],
+                      num_blocks  = num_blocks,
+                      shortcut    = shortcut,
+                      expansion   = 0.5,
+                      act_type    = act_type,
+                      norm_type   = norm_type,
+                      depthwise   = depthwise),
+            BasicConv(inter_dims[1], inter_dims[1], kernel_size=3, padding=1,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        self.elan_module_2 = nn.Sequential(
+             CSPLayer(inter_dims[1],
+                      inter_dims[1],
+                      num_blocks  = num_blocks,
+                      shortcut    = shortcut,
+                      expansion   = 0.5,
+                      act_type    = act_type,
+                      norm_type   = norm_type,
+                      depthwise   = depthwise),
+            BasicConv(inter_dims[1], inter_dims[1], kernel_size=3, padding=1,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        self.conv_layer_2 = BasicConv(inter_dims[0] + 2*self.inter_dims[1], out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+
+    def forward(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.conv_layer_1(x), 2, dim=1)
+        out = list([x1, x2])
+
+        # ELAN module
+        out.append(self.elan_module_1(out[-1]))
+        out.append(self.elan_module_2(out[-1]))
+
+        # Output proj
+        out = self.conv_layer_2(torch.cat(out, dim=1))
+
+        return out
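The output projection width in `ElanLayer` follows from the branch count: the 1x1 input projection yields `2 * inter_dim` channels and each bottleneck appends another `inter_dim`, so the concatenation feeding `output_proj` has `(2 + num_blocks) * inter_dim` channels. A quick shape check, assuming the module is importable from this path:

```python
import torch
from yolo.models.rtcdet.rtcdet_basic import ElanLayer  # assumption: repo root on PYTHONPATH

layer = ElanLayer(in_dim=64, out_dim=128, expansion=0.5, num_blocks=2, shortcut=True)
x = torch.randn(1, 64, 80, 80)
y = layer(x)
print(y.shape)  # torch.Size([1, 128, 80, 80]); concat before output_proj has (2 + 2) * 64 = 256 channels
```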

+ 280 - 0
yolo/models/rtcdet/rtcdet_head.py

@@ -0,0 +1,280 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .rtcdet_basic import BasicConv
+except:
+    from  rtcdet_basic import BasicConv
+    
+
+# -------------------- Detection Head --------------------
+## Single-level Detection Head
+class DetHead(nn.Module):
+    def __init__(self,
+                 in_dim       :int  = 256,
+                 cls_head_dim :int  = 256,
+                 reg_head_dim :int  = 256,
+                 num_cls_head :int  = 2,
+                 num_reg_head :int  = 2,
+                 act_type     :str  = "silu",
+                 norm_type    :str  = "bn",
+                 depthwise    :bool = False):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dim = in_dim
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        self.act_type = act_type
+        self.norm_type = norm_type
+        self.depthwise = depthwise
+        
+        # --------- Network Parameters ----------
+        ## cls head
+        cls_feats = []
+        self.cls_head_dim = cls_head_dim
+        for i in range(num_cls_head):
+            if i == 0:
+                cls_feats.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                cls_feats.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        ## reg head
+        reg_feats = []
+        self.reg_head_dim = reg_head_dim
+        for i in range(num_reg_head):
+            if i == 0:
+                reg_feats.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, groups=4, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                reg_feats.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, groups=4,
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, x):
+        """
+            x: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+## Multi-scale Detection Head
+class MSDetHead(nn.Module):
+    def __init__(self, cfg, in_dims):
+        super().__init__()
+        ## ----------- Network Parameters -----------
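+        ## Head widths follow YOLOv8-style heuristics:
+        ##   cls branch width: max(C, min(num_classes, 128))
+        ##   reg branch width: max(C // 4, 16, 4 * reg_max)
+        ## where C = in_dims[0], the width of the first (highest-resolution) level.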
+        self.multi_level_heads = nn.ModuleList(
+            [DetHead(in_dim       = in_dims[level],
+                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
+                     reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
+                     num_cls_head = cfg.num_cls_head,
+                     num_reg_head = cfg.num_reg_head,
+                     act_type     = cfg.head_act,
+                     norm_type    = cfg.head_norm,
+                     depthwise    = cfg.head_depthwise)
+                     for level in range(cfg.num_levels)
+                     ])
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
+        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
+
+
+    def forward(self, feats):
+        """
+            feats: List[(Tensor)] [[B, C, H, W], ...]
+        """
+        cls_feats = []
+        reg_feats = []
+        for feat, head in zip(feats, self.multi_level_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+
+        return cls_feats, reg_feats
+
+
+# -------------------- Segmentation Head --------------------
+## Single-level Segmentation Head
+class SegHead(nn.Module):
+    def __init__(self,
+                 in_dim       :int  = 256,
+                 cls_head_dim :int  = 256,
+                 reg_head_dim :int  = 256,
+                 seg_head_dim :int  = 256,
+                 num_cls_head :int  = 2,
+                 num_reg_head :int  = 2,
+                 num_seg_head :int  = 2,
+                 act_type     :str  = "silu",
+                 norm_type    :str  = "BN",
+                 depthwise    :bool = False):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dim = in_dim
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        self.num_seg_head = num_seg_head
+        self.act_type = act_type
+        self.norm_type = norm_type
+        self.depthwise = depthwise
+        
+        # --------- Network Parameters ----------
+        ## cls head
+        cls_feats = []
+        self.cls_head_dim = cls_head_dim
+        for i in range(num_cls_head):
+            if i == 0:
+                cls_feats.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                cls_feats.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        ## reg head
+        reg_feats = []
+        self.reg_head_dim = reg_head_dim
+        for i in range(num_reg_head):
+            if i == 0:
+                reg_feats.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                reg_feats.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        ## seg head
+        seg_feats = []
+        self.seg_head_dim = seg_head_dim
+        for i in range(num_seg_head):
+            if i == 0:
+                seg_feats.append(
+                    BasicConv(in_dim, self.seg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                seg_feats.append(
+                    BasicConv(self.seg_head_dim, self.seg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+        self.seg_feats = nn.Sequential(*seg_feats)
+
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # Reset Conv2d parameters to the default initialization,
+                # to stay consistent with the reference implementation.
+                m.reset_parameters()
+
+    def forward(self, x):
+        """
+            x: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+        seg_feats = self.seg_feats(x)
+
+        return cls_feats, reg_feats, seg_feats
+    
+## Multi-scale Segmentation Head
+class MSSegHead(nn.Module):
+    def __init__(self, cfg, in_dims):
+        super().__init__()
+        ## ----------- Network Parameters -----------
+        self.multi_level_heads = nn.ModuleList(
+            [SegHead(in_dim       = in_dims[level],
+                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
+                     reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
+                     seg_head_dim = in_dims[0],
+                     num_cls_head = cfg.num_cls_head,
+                     num_reg_head = cfg.num_reg_head,
+                     num_seg_head = cfg.num_seg_head,
+                     act_type     = cfg.head_act,
+                     norm_type    = cfg.head_norm,
+                     depthwise    = cfg.head_depthwise)
+                     for level in range(cfg.num_levels)
+                     ])
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
+        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
+        self.seg_head_dim = self.multi_level_heads[0].seg_head_dim
+
+    def forward(self, feats):
+        """
+            feats: List[(Tensor)] [[B, C, H, W], ...]
+        """
+        cls_feats = []
+        reg_feats = []
+        seg_feats = []
+        for feat, head in zip(feats, self.multi_level_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat, seg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+            seg_feats.append(seg_feat)
+
+        return cls_feats, reg_feats, seg_feats

+ 39 - 0
yolo/models/rtcdet/rtcdet_neck.py

@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+
+from .rtcdet_basic import BasicConv
+
+
+# -------------- Neck network --------------
+class SPPF(nn.Module):
+    """
+        This code references https://github.com/ultralytics/yolov5
+    """
+    def __init__(self, cfg, in_dim, out_dim):
+        super().__init__()
+        ## ----------- Basic Parameters -----------
+        inter_dim = round(in_dim * cfg.neck_expand_ratio)
+        self.out_dim = out_dim
+        ## ----------- Network Parameters -----------
+        self.input_proj  = BasicConv(in_dim, inter_dim, kernel_size=1,
+                                     act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.output_proj = BasicConv(inter_dim * 4, out_dim, kernel_size=1,
+                                     act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.module = nn.MaxPool2d(cfg.spp_pooling_size, stride=1, padding=cfg.spp_pooling_size//2)
+
+        # Initialize all layers
+        self.init_weights()
+                
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
+
+    def forward(self, x):
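+        # Three chained max-pools with the configured kernel size; for a 5x5 kernel
+        # this gives effective receptive fields of 5, 9 and 13, equivalent to the
+        # original SPP with parallel 5/9/13 pooling but cheaper.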
+        x = self.input_proj(x)
+        y1 = self.module(x)
+        y2 = self.module(y1)
+
+        return self.output_proj(torch.cat((x, y1, y2, self.module(y2)), 1))
+    

+ 108 - 0
yolo/models/rtcdet/rtcdet_pafpn.py

@@ -0,0 +1,108 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+
+try:
+    from .rtcdet_basic import BasicConv, DWConv, ElanLayer, MDown, ADown
+except ImportError:
+    from rtcdet_basic import BasicConv, DWConv, ElanLayer, MDown, ADown
+
+
+# -------------- Feature pyramid network --------------
+class RTCPaFPN(nn.Module):
+    def __init__(self,
+                 cfg,
+                 in_dims :List = [256, 512, 1024],
+                 ) -> None:
+        super(RTCPaFPN, self).__init__()
+        print('==============================')
+        print('FPN: {}'.format("RTC-PaFPN"))
+        # ----------- Basic Parameters -----------
+        self.in_dims = in_dims[::-1]
+
+        # ----------- Yolov8's Top-down FPN -----------
+        ## P5 -> P4
+        self.top_down_layer_1 = self.make_fpn_block(cfg, self.in_dims[0] + self.in_dims[1], round(512*cfg.width), round(3 * cfg.depth))
+        ## P4 -> P3
+        self.top_down_layer_2 = self.make_fpn_block(cfg, self.in_dims[2] + round(512*cfg.width), round(256*cfg.width), round(3 * cfg.depth))
+
+        # ----------- Yolov8's Bottom-up PAN -----------
+        ## P3 -> P4
+        self.downsample_layer_1 = self.make_downsample_block(cfg, round(256*cfg.width), round(256*cfg.width))
+        self.bottom_up_layer_1  = self.make_fpn_block(cfg, round(256*cfg.width) + round(512*cfg.width), round(512*cfg.width), round(3 * cfg.depth))
+        ## P4 -> P5
+        self.downsample_layer_2 = self.make_downsample_block(cfg, round(512*cfg.width), round(512*cfg.width))
+        self.bottom_up_layer_2  = self.make_fpn_block(cfg, round(512*cfg.width) + self.in_dims[0], round(512*cfg.width*cfg.ratio), round(3 * cfg.depth))
+
+        # ----------- Output projection -----------
+        self.out_layers = nn.ModuleList([
+            BasicConv(in_dim, round(256*cfg.width), kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
+                      for in_dim in [round(256*cfg.width), round(512*cfg.width), round(512*cfg.width*cfg.ratio)]])
+        self.out_dims = [round(256*cfg.width)] * 3
+
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
+
+    def make_downsample_block(self, cfg, in_dim, out_dim):
+        if cfg.fpn_ds_block == "conv":
+            return BasicConv(in_dim, out_dim, kernel_size=3, padding=1, stride=2,
+                             act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
+        elif cfg.fpn_ds_block == "dw_conv":
+            return DWConv(in_dim, out_dim, kernel_size=3, padding=1, stride=2,
+                          act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
+        elif cfg.fpn_ds_block == "mdown":
+            return MDown(in_dim, out_dim, cfg.bk_act, cfg.bk_norm, cfg.bk_depthwise)
+        elif cfg.fpn_ds_block == "adown":
+            return ADown(in_dim, out_dim, cfg.bk_act, cfg.bk_norm, cfg.bk_depthwise)
+        else:
+            raise NotImplementedError("Unknown fpn downsample block: {}".format(cfg.fpn_ds_block))
+        
+    def make_fpn_block(self, cfg, in_dim, out_dim, block_depth):
+        if cfg.fpn_block == "elan_layer":
+            return ElanLayer(in_dim     = in_dim,
+                             out_dim    = out_dim,
+                             num_blocks = block_depth,
+                             expansion  = 0.5,
+                             shortcut   = False,
+                             act_type   = cfg.fpn_act,
+                             norm_type  = cfg.fpn_norm,
+                             depthwise  = cfg.fpn_depthwise)
+        else:
+            raise NotImplementedError("Unknown fpn block: {}".format(cfg.fpn_block))
+        
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # ------------------ Top down FPN ------------------
+        ## P5 -> P4
+        p5_up = F.interpolate(c5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([p5_up, c4], dim=1))
+
+        ## P4 -> P3
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([p4_up, c3], dim=1))
+
+        # ------------------ Bottom up FPN ------------------
+        ## p3 -> P4
+        p3_ds = self.downsample_layer_1(p3)
+        p4 = self.bottom_up_layer_1(torch.cat([p3_ds, p4], dim=1))
+
+        ## P4 -> P5
+        p4_ds = self.downsample_layer_2(p4)
+        p5 = self.bottom_up_layer_2(torch.cat([p4_ds, c5], dim=1))
+
+        out_feats = [p3, p4, p5] # [P3, P4, P5]
+                
+        # ------------------ Output projection ------------------
+        out_feats_proj = []
+        for feat, layer in zip(out_feats, self.out_layers):
+            out_feats_proj.append(layer(feat))
+            
+        return out_feats_proj
+    

+ 330 - 0
yolo/models/rtcdet/rtcdet_pred.py

@@ -0,0 +1,330 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# -------------------- Detection Pred Layer --------------------
+## Single-level pred layer
+class DetPredLayer(nn.Module):
+    def __init__(self,
+                 cls_dim     :int = 256,
+                 reg_dim     :int = 256,
+                 stride      :int = 32,
+                 num_classes :int = 80,
+                 num_coords  :int = 4):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.stride = stride
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.num_classes = num_classes
+        self.num_coords = num_coords
+
+        # --------- Network Parameters ----------
+        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
+        self.reg_pred = nn.Conv2d(reg_dim, num_coords,  kernel_size=1, groups=4)                
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # cls pred bias
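+        # Prior-probability init: with a 640x640 input, this makes the expected
+        # number of positive class predictions per level roughly 5 at the start
+        # of training (the usual focal-loss-style bias prior).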
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred bias
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = self.reg_pred.weight
+        w.data.fill_(0.)
+        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.stride
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat):
+        # pred
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+
+        # generate anchor boxes: [M, 4]
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+        # stride tensor: [M, 1]
+        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
+        
+        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_coords)
+        
+        # output dict
+        outputs = {"pred_cls": cls_pred,            # (Tensor) [B, M, C]
+                   "pred_reg": reg_pred,            # (Tensor) [B, M, 4*reg_max]
+                   "anchors": anchors,              # (Tensor) [M, 2]
+                   "strides": self.stride,          # (Int) stride of this level, e.g. 8, 16 or 32
+                   "stride_tensor": stride_tensor   # (Tensor) [M, 1]
+                   }
+
+        return outputs
+
+## Multi-scale pred layer
+class MSDetPredLayer(nn.Module):
+    def __init__(self,
+                 cfg,
+                 cls_dim,
+                 reg_dim,
+                 ):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.reg_max    = cfg.reg_max
+        self.num_levels = cfg.num_levels
+        self.out_stride = cfg.out_stride
+
+        # ----------- Network Parameters -----------
+        ## pred layers
+        self.multi_level_preds = nn.ModuleList(
+            [DetPredLayer(cls_dim     = cls_dim,
+                          reg_dim     = reg_dim,
+                          stride      = cfg.out_stride[level],
+                          num_classes = cfg.num_classes,
+                          num_coords  = cfg.reg_max * 4)
+                          for level in range(cfg.num_levels)
+                          ])
+        ## proj conv
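+        ## Fixed (non-trainable) 1x1 conv whose weights are [0, 1, ..., reg_max-1];
+        ## applied after a softmax over the reg_max bins, it yields the expected bin
+        ## index per box side, i.e. the DFL-style decoding of the regression output.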
+        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
+        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
+        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, cfg.reg_max, 1, 1]), requires_grad=False)
+
+    def forward(self, cls_feats, reg_feats):
+        all_anchors = []
+        all_strides = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        all_delta_preds = []
+        for level in range(self.num_levels):
+            # -------------- Single-level prediction --------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
+
+            # -------------- Decode bbox --------------
+            B, M = outputs["pred_reg"].shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
+            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## tlbr -> xyxy
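+            ## delta_pred holds (left, top, right, bottom) distances in grid units;
+            ## multiplying by the level stride converts them to pixel offsets from
+            ## the anchor centers (which are already in input-image coordinates).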
+            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.out_stride[level]
+            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.out_stride[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+            # collect results
+            all_cls_preds.append(outputs["pred_cls"])
+            all_reg_preds.append(outputs["pred_reg"])
+            all_box_preds.append(box_pred)
+            all_delta_preds.append(delta_pred)
+            all_anchors.append(outputs["anchors"])
+            all_strides.append(outputs["stride_tensor"])
+        
+        # output dict
+        outputs = {"pred_cls":      all_cls_preds,     # List(Tensor) [B, M, C]
+                   "pred_reg":      all_reg_preds,     # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":      all_box_preds,     # List(Tensor) [B, M, 4]
+                   "pred_delta":    all_delta_preds,   # List(Tensor) [B, M, 4]
+                   "anchors":       all_anchors,       # List(Tensor) [M, 2]
+                   "stride_tensor": all_strides,       # List(Tensor) [M, 1]
+                   "strides":       self.out_stride,   # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs
+
+
+# -------------------- Segmentation Pred Layer --------------------
+## Single-level pred layer
+class SegPredLayer(nn.Module):
+    def __init__(self,
+                 cls_dim     :int = 256,
+                 reg_dim     :int = 256,
+                 seg_dim     :int = 256,
+                 stride      :int = 32,
+                 num_classes :int = 80,
+                 num_coords  :int = 4,
+                 num_masks   :int = 1):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.stride = stride
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.seg_dim = seg_dim
+        self.num_classes = num_classes
+        self.num_coords = num_coords
+        self.num_masks = num_masks
+
+        # --------- Network Parameters ----------
+        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
+        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)                
+        self.seg_pred = nn.Conv2d(seg_dim, num_masks, kernel_size=1)                
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # cls pred bias
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred bias
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = self.reg_pred.weight
+        w.data.fill_(0.)
+        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+        # seg pred bias
+        b = self.seg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.seg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = self.seg_pred.weight
+        w.data.fill_(0.)
+        self.seg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.stride
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat, seg_feat):
+        # pred
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+        seg_pred = self.seg_pred(seg_feat)
+
+        # generate anchor boxes: [M, 4]
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+        # stride tensor: [M, 1]
+        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
+        
+        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_coords)
+        seg_pred = seg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_masks)
+        
+        # output dict
+        outputs = {"pred_cls": cls_pred,            # (Tensor) [B, M, Nc]
+                   "pred_reg": reg_pred,            # (Tensor) [B, M, Na]
+                   "pred_seg": seg_pred,            # (Tensor) [B, M, Nm]
+                   "anchors": anchors,              # (Tensor) [M, 2]
+                   "strides": self.stride,          # (Int) stride of this level, e.g. 8, 16 or 32
+                   "stride_tensor": stride_tensor   # (Tensor) [M, 1]
+                   }
+
+        return outputs
+
+## Multi-level pred layer
+class RTCSegPredLayer(nn.Module):
+    def __init__(self,
+                 cfg,
+                 cls_dim,
+                 reg_dim,
+                 seg_dim,
+                 ):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.seg_dim = seg_dim
+
+        # ----------- Network Parameters -----------
+        ## pred layers
+        self.multi_level_preds = nn.ModuleList(
+            [SegPredLayer(cls_dim     = cls_dim,
+                          reg_dim     = reg_dim,
+                          seg_dim     = seg_dim,
+                          stride      = cfg.out_stride[level],
+                          num_classes = cfg.num_classes,
+                          num_coords  = cfg.reg_max * 4,
+                          num_masks   = cfg.mask_dim)
+                          for level in range(cfg.num_levels)
+                          ])
+        ## proj conv
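+        ## Same fixed DFL projection as in MSDetPredLayer: expected value over the
+        ## reg_max bins for each box side.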
+        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
+        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
+        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, cfg.reg_max, 1, 1]), requires_grad=False)
+
+    def forward(self, cls_feats, reg_feats, seg_feats):
+        all_anchors = []
+        all_strides = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        all_seg_preds = []
+        for level in range(self.cfg.num_levels):
+            # -------------- Single-level prediction --------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level], seg_feats[level])
+
+            # -------------- Decode bbox --------------
+            B, M = outputs["pred_reg"].shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
+            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.cfg.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## tlbr -> xyxy
+            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.cfg.out_stride[level]
+            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.cfg.out_stride[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+            # collect results
+            all_cls_preds.append(outputs["pred_cls"])
+            all_reg_preds.append(outputs["pred_reg"])
+            all_seg_preds.append(outputs["pred_seg"])
+            all_box_preds.append(box_pred)
+            all_anchors.append(outputs["anchors"])
+            all_strides.append(outputs["stride_tensor"])
+        
+        # output dict
+        outputs = {"pred_cls":      all_cls_preds,         # List(Tensor) [B, M, C]
+                   "pred_reg":      all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":      all_box_preds,         # List(Tensor) [B, M, 4]
+                   "pred_seg":      all_seg_preds,         # List(Tensor) [B, M, 4]
+                   "anchors":       all_anchors,           # List(Tensor) [M, 2]
+                   "stride_tensor": all_strides,           # List(Tensor) [M, 1]
+                   "strides":       self.cfg.out_stride,   # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs

+ 2 - 1
yolo/train.py

@@ -13,7 +13,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 
 # ----------------- Extra Components -----------------
 from utils import distributed_utils
-from utils.misc import compute_flops, build_dataloader, CollateFunc, ModelEMA
+from utils.misc import compute_flops, build_dataloader, CollateFunc
+from utils.ema  import ModelEMA
 
 # ----------------- Config Components -----------------
 from config import build_config

+ 69 - 0
yolo/utils/ema.py

@@ -0,0 +1,69 @@
+# =====================================================================
+# Copyright 2021 RangiLyu. All rights reserved.
+# =====================================================================
+# Modified from: https://github.com/facebookresearch/d2go
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# Licensed under the Apache License, Version 2.0 (the "License")
+import math
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+
+
+# Modified from the YOLOv5 project
+class ModelEMA(object):
+    def __init__(self, model, ema_decay=0.9999, ema_tau=2000, resume=None):
+        # Create EMA
+        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
+        self.updates = 0  # number of EMA updates
+        self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau))  # decay exponential ramp (to help early epochs)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+        if resume is not None and resume.lower() != "none":
+            self.load_resume(resume)
+
+        print("Initialize ModelEMA's updates: {}".format(self.updates))
+
+    def load_resume(self, resume):
+        checkpoint = torch.load(resume)
+        if 'model_ema' in checkpoint.keys():
+            print('--Load ModelEMA state dict from the checkpoint: ', resume)
+            model_ema_state_dict = checkpoint["model_ema"]
+            self.ema.load_state_dict(model_ema_state_dict)
+        if 'ema_updates' in checkpoint.keys():
+            print('--Load ModelEMA updates from the checkpoint: ', resume)
+            # checkpoint state dict
+            self.updates = checkpoint.pop("ema_updates")
+
+    def is_parallel(self, model):
+        # Returns True if model is of type DP or DDP
+        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+    def de_parallel(self, model):
+        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+        return model.module if self.is_parallel(model) else model
+
+    def copy_attr(self, a, b, include=(), exclude=()):
+        # Copy attributes from b to a, options to only include [...] and to exclude [...]
+        for k, v in b.__dict__.items():
+            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
+                continue
+            else:
+                setattr(a, k, v)
+
+    def update(self, model):
+        # Update EMA parameters
+        self.updates += 1
+        d = self.decay(self.updates)
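+        # d ramps from ~0 toward ema_decay as the update count grows, so the EMA
+        # tracks the live model closely early on and smooths more strongly later.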
+
+        msd = self.de_parallel(model).state_dict()  # model state_dict
+        for k, v in self.ema.state_dict().items():
+            if v.dtype.is_floating_point:  # true for FP16 and FP32
+                v *= d
+                v += (1 - d) * msd[k].detach()
+
+    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+        # Update EMA attributes
+        self.copy_attr(self.ema, model, include, exclude)

+ 6 - 1
yolo/utils/misc.py

@@ -354,7 +354,12 @@ def load_weight(model, path_to_ckpt, fuse_cbn=False, rep_conv=False):
         print('Epoch: {}'.format(checkpoint["epoch"]))
         print('mAP: {}'.format(checkpoint["mAP"]))
         print('--------------------------------------')
-        checkpoint_state_dict = checkpoint["model"]
+        if "model_ema" in checkpoint:
+            print("Load the model from the ModelEMA state dict ...")
+            checkpoint_state_dict = checkpoint["model_ema"]
+        else:
+            print("Load the model state dict ...")
+            checkpoint_state_dict = checkpoint["model"]
         model.load_state_dict(checkpoint_state_dict)
 
         print('Finished loading model!')