Browse Source

modify RTCDet, according to YOLOv8

yjh0410 2 years ago
parent
commit
f21372262d

+ 63 - 69
config/model_config/rtcdet_config.py

@@ -42,20 +42,19 @@ rtcdet_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.4, 1.0], # 256 -> 640
-        'trans_type': 'yolox_pico',
+        'trans_type': 'yolov5_pico',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': "simota",
-        'matcher_hpy': {'center_sampling_radius': 2.5,
-                        'topk_candidate': 10},
+        'matcher': "tal",
+        'matcher_hpy': {'topk_candidates': 10,
+                        'alpha': 0.5,
+                        'beta':  6.0},
         # ---------------- Loss config ----------------
-        'cls_loss': 'bce',
-        'loss_cls_weight': 1.0,
-        'loss_dfl_weight': 1.0,
-        'loss_box_weight': 5.0,
-        'loss_box_aux': True,
+        'loss_cls_weight': 0.5,
+        'loss_dfl_weight': 1.5,
+        'loss_box_weight': 7.5,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtcdet',
+        'trainer_type': 'yolov8',
     },
 
     'rtcdet_n':{
@@ -98,20 +97,19 @@ rtcdet_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.5], # 320 -> 960
-        'trans_type': 'yolox_nano',
+        'trans_type': 'yolov5_nano',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': "simota",
-        'matcher_hpy': {'center_sampling_radius': 2.5,
-                        'topk_candidate': 10},
+        'matcher': "tal",
+        'matcher_hpy': {'topk_candidates': 10,
+                        'alpha': 0.5,
+                        'beta':  6.0},
         # ---------------- Loss config ----------------
-        'cls_loss': 'qfl',
-        'loss_cls_weight': 1.0,
-        'loss_dfl_weight': 1.0,
-        'loss_box_weight': 5.0,
-        'loss_box_aux': True,
+        'loss_cls_weight': 0.5,
+        'loss_dfl_weight': 1.5,
+        'loss_box_weight': 7.5,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtcdet',
+        'trainer_type': 'yolov8',
     },
 
     'rtcdet_t':{
@@ -154,20 +152,19 @@ rtcdet_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.5], # 320 -> 960
-        'trans_type': 'yolox_small',
+        'trans_type': 'yolov5_small',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': "simota",
-        'matcher_hpy': {'center_sampling_radius': 2.5,
-                        'topk_candidate': 10},
+        'matcher': "tal",
+        'matcher_hpy': {'topk_candidates': 10,
+                        'alpha': 0.5,
+                        'beta':  6.0},
         # ---------------- Loss config ----------------
-        'cls_loss': 'bce',
-        'loss_cls_weight': 1.0,
-        'loss_dfl_weight': 1.0,
-        'loss_box_weight': 5.0,
-        'loss_box_aux': True,
+        'loss_cls_weight': 0.5,
+        'loss_dfl_weight': 1.5,
+        'loss_box_weight': 7.5,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtcdet',
+        'trainer_type': 'yolov8',
     },
 
     'rtcdet_s':{
@@ -210,19 +207,19 @@ rtcdet_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.5], # 320 -> 960
-        'trans_type': 'yolox_small',
+        'trans_type': 'yolov5_small',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': "simota",
-        'matcher_hpy': {'center_sampling_radius': 2.5,
-                        'topk_candidate': 10},
+        'matcher': "tal",
+        'matcher_hpy': {'topk_candidates': 10,
+                        'alpha': 0.5,
+                        'beta':  6.0},
         # ---------------- Loss config ----------------
-        'loss_cls_weight': 1.0,
-        'loss_dfl_weight': 1.0,
-        'loss_box_weight': 5.0,
-        'loss_box_aux': True,
+        'loss_cls_weight': 0.5,
+        'loss_dfl_weight': 1.5,
+        'loss_box_weight': 7.5,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtcdet',
+        'trainer_type': 'yolov8',
     },
 
     'rtcdet_m':{
@@ -265,20 +262,19 @@ rtcdet_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.5], # 320 -> 960
-        'trans_type': 'yolox_medium',
+        'trans_type': 'yolov5_medium',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': "simota",
-        'matcher_hpy': {'center_sampling_radius': 2.5,
-                        'topk_candidate': 10},
+        'matcher': "tal",
+        'matcher_hpy': {'topk_candidates': 10,
+                        'alpha': 0.5,
+                        'beta':  6.0},
         # ---------------- Loss config ----------------
-        'cls_loss': 'bce',
-        'loss_cls_weight': 1.0,
-        'loss_dfl_weight': 1.0,
-        'loss_box_weight': 5.0,
-        'loss_box_aux': True,
+        'loss_cls_weight': 0.5,
+        'loss_dfl_weight': 1.5,
+        'loss_box_weight': 7.5,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtcdet',
+        'trainer_type': 'yolov8',
     },
 
     'rtcdet_l':{
@@ -321,20 +317,19 @@ rtcdet_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.5], # 320 -> 960
-        'trans_type': 'yolox_large',
+        'trans_type': 'yolov5_large',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': "simota",
-        'matcher_hpy': {'center_sampling_radius': 2.5,
-                        'topk_candidate': 10},
+        'matcher': "tal",
+        'matcher_hpy': {'topk_candidates': 10,
+                        'alpha': 0.5,
+                        'beta':  6.0},
         # ---------------- Loss config ----------------
-        'cls_loss': 'bce',
-        'loss_cls_weight': 1.0,
-        'loss_dfl_weight': 1.0,
-        'loss_box_weight': 5.0,
-        'loss_box_aux': True,
+        'loss_cls_weight': 0.5,
+        'loss_dfl_weight': 1.5,
+        'loss_box_weight': 7.5,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtcdet',
+        'trainer_type': 'yolov8',
     },
 
     'rtcdet_x':{
@@ -377,20 +372,19 @@ rtcdet_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.5], # 320 -> 960
-        'trans_type': 'yolox_huge',
+        'trans_type': 'yolov5_huge',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': "simota",
-        'matcher_hpy': {'center_sampling_radius': 2.5,
-                        'topk_candidate': 10},
+        'matcher': "tal",
+        'matcher_hpy': {'topk_candidates': 10,
+                        'alpha': 0.5,
+                        'beta':  6.0},
         # ---------------- Loss config ----------------
-        'cls_loss': 'bce',
-        'loss_cls_weight': 1.0,
-        'loss_dfl_weight': 1.0,
-        'loss_box_weight': 5.0,
-        'loss_box_aux': True,
+        'loss_cls_weight': 0.5,
+        'loss_dfl_weight': 1.5,
+        'loss_box_weight': 7.5,
         # ---------------- Train config ----------------
-        'trainer_type': 'rtcdet',
+        'trainer_type': 'yolov8',
     },
 
 }

+ 2 - 2
models/detectors/rtcdet/README.md

@@ -4,8 +4,8 @@
 |----------|-------|-------|-------------------------|--------------------|------------------------|-------------------|-------------------|--------------------|--------|
 | RTCDet-N |  640  | 8xb16 |                         |                    |                        |                   |                   |                    |  |
 | RTCDet-T |  640  | 8xb16 |                         |                    |                        |                   |                   |                    |  |
-| RTCDet-S |  640  | 8xb16 |                         |                    |           44.5         |       63.5        |        30.9       |         8.5        | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/rtcdet_s_coco.pth) |
-| RTCDet-M |  640  | 8xb16 |                         |                    |           48.7         |       67.6        |        80.3       |         22.6       | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/rtcdet_m_coco.pth) |
+| RTCDet-S |  640  | 8xb16 |                         |                    |                        |                   |                   |                    |  |
+| RTCDet-M |  640  | 8xb16 |                         |                    |                        |                   |                   |                    |  |
 | RTCDet-L |  640  | 8xb16 |                         |                    |                        |                   |                   |                    |  |
 | RTCDet-X |  640  | 8xb16 |                         |                    |                        |                   |                   |                    |  |
 

+ 2 - 1
models/detectors/rtcdet/build.py

@@ -36,5 +36,6 @@ def build_rtcdet(args, cfg, device, num_classes=80, trainable=False, deploy=Fals
     criterion = None
     if trainable:
         # build criterion for training
-        criterion = build_criterion(args, cfg, device, num_classes)
+        criterion = build_criterion(cfg, device, num_classes)
+        
     return model, criterion

+ 185 - 94
models/detectors/rtcdet/loss.py

@@ -1,44 +1,43 @@
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
-from utils.box_ops import bbox2dist, get_ious
+from utils.box_ops import bbox2dist, bbox_iou
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
-from .matcher import SimOTA
+from .matcher import TaskAlignedAssigner
 
 
-# ----------------------- Criterion for training -----------------------
 class Criterion(object):
-    def __init__(self, args, cfg, device, num_classes=80):
+    def __init__(self, cfg, device, num_classes=80):
+        # --------------- Basic parameters ---------------
         self.cfg = cfg
-        self.args = args
         self.device = device
         self.num_classes = num_classes
-        self.max_epoch = args.max_epoch
-        self.no_aug_epoch = args.no_aug_epoch
-        # ---------------- Loss weight ----------------
-        self.loss_box_aux    = cfg['loss_box_aux']
+        self.reg_max = cfg['reg_max']
+        self.use_dfl = cfg['reg_max'] > 1
+        # --------------- Loss config ---------------
         self.loss_cls_weight = cfg['loss_cls_weight']
         self.loss_box_weight = cfg['loss_box_weight']
         self.loss_dfl_weight = cfg['loss_dfl_weight']
-        # ---------------- Matcher ----------------
-        ## Aligned SimOTA assigner
+        # --------------- Matcher config ---------------
         self.matcher_hpy = cfg['matcher_hpy']
-        self.matcher = SimOTA(num_classes            = num_classes,
-                              center_sampling_radius = self.matcher_hpy['center_sampling_radius'],
-                              topk_candidate         = self.matcher_hpy['topk_candidate'])
+        self.matcher = TaskAlignedAssigner(num_classes     = num_classes,
+                                           topk_candidates = self.matcher_hpy['topk_candidates'],
+                                           alpha           = self.matcher_hpy['alpha'],
+                                           beta            = self.matcher_hpy['beta']
+                                           )
 
-    # ----------------- Loss functions -----------------
     def loss_classes(self, pred_cls, gt_score):
         # compute bce loss
         loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
 
         return loss_cls
-
-    def loss_bboxes(self, pred_box, gt_box):
+    
+    def loss_bboxes(self, pred_box, gt_box, bbox_weight):
         # regression loss
-        ious = get_ious(pred_box, gt_box, 'xyxy', 'giou')
-        loss_box = 1.0 - ious
+        ious = bbox_iou(pred_box, gt_box, xywh=False, CIoU=True)
+        loss_box = (1.0 - ious.squeeze(-1)) * bbox_weight
 
         return loss_box
     
@@ -74,89 +73,90 @@ class Criterion(object):
 
         return loss_dfl
 
-    def loss_bboxes_aux(self, pred_delta, gt_box, anchors, stride_tensors):
-        gt_delta_tl = (anchors - gt_box[..., :2]) / stride_tensors
-        gt_delta_rb = (gt_box[..., 2:] - anchors) / stride_tensors
-        gt_delta = torch.cat([gt_delta_tl, gt_delta_rb], dim=1)
-        loss_box_aux = F.l1_loss(pred_delta, gt_delta, reduction='none')
-
-        return loss_box_aux
-    
-    # ----------------- Main process -----------------
-    def __call__(self, outputs, targets, epoch=0):
+    def __call__(self, outputs, targets, epoch=0):        
+        """
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_reg']: List(Tensor) [B, M, 4*(reg_max+1)]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['anchors']: List(Tensor) [M, 2]
+            outputs['strides']: List(Int) [8, 16, 32] output stride
+            outputs['stride_tensor']: List(Tensor) [M, 1]
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device
-        fpn_strides = outputs['strides']
+        strides = outputs['stride_tensor']
         anchors = outputs['anchors']
-        num_anchors = sum([ab.shape[0] for ab in anchors])
+        anchors = torch.cat(anchors, dim=0)
+        num_anchors = anchors.shape[0]
+
         # preds: [B, M, C]
         cls_preds = torch.cat(outputs['pred_cls'], dim=1)
         reg_preds = torch.cat(outputs['pred_reg'], dim=1)
         box_preds = torch.cat(outputs['pred_box'], dim=1)
-
+        
         # --------------- label assignment ---------------
-        cls_targets = []
-        box_targets = []
+        gt_score_targets = []
+        gt_bbox_targets = []
         fg_masks = []
         for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)
-            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
+            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
 
             # check target
-            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
+            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
                 # There is no valid gt
-                cls_target = cls_preds.new_zeros((num_anchors, self.num_classes))
-                box_target = cls_preds.new_zeros((0, 4))
-                fg_mask = cls_preds.new_zeros(num_anchors).bool()
+                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
+                gt_box = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
             else:
+                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
+                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
                 (
-                    fg_mask,
-                    assigned_labels,
-                    assigned_ious,
-                    assigned_indexs
+                    _,
+                    gt_box,     # [1, M, 4]
+                    gt_score,   # [1, M, C]
+                    fg_mask,    # [1, M,]
+                    _
                 ) = self.matcher(
-                    fpn_strides = fpn_strides,
-                    anchors = anchors,
-                    pred_cls = cls_preds[batch_idx], 
-                    pred_box = box_preds[batch_idx],
-                    tgt_labels = tgt_labels,
-                    tgt_bboxes = tgt_bboxes
+                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    anc_points = anchors,
+                    gt_labels = tgt_labels,
+                    gt_bboxes = tgt_boxs
                     )
-                # prepare cls targets
-                assigned_labels = F.one_hot(assigned_labels.long(), self.num_classes)
-                assigned_labels = assigned_labels * assigned_ious.unsqueeze(-1)
-                cls_target = assigned_labels.new_zeros((num_anchors, self.num_classes))
-                cls_target[fg_mask] = assigned_labels
-                # prepare box targets
-                box_target = tgt_bboxes[assigned_indexs]
-
-            cls_targets.append(cls_target)
-            box_targets.append(box_target)
+            gt_score_targets.append(gt_score)
+            gt_bbox_targets.append(gt_box)
             fg_masks.append(fg_mask)
 
-        cls_targets = torch.cat(cls_targets, 0)
-        box_targets = torch.cat(box_targets, 0)
-        fg_masks = torch.cat(fg_masks, 0)
-        num_fgs = fg_masks.sum()
-
-        # average loss normalizer across all the GPUs
+        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
+        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
+        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
+        num_fgs = gt_score_targets.sum()
+        
+        # Average loss normalizer across all the GPUs
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_fgs)
         num_fgs = (num_fgs / get_world_size()).clamp(1.0)
-        
+
         # ------------------ Classification loss ------------------
         cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, cls_targets)
+        loss_cls = self.loss_classes(cls_preds, gt_score_targets)
         loss_cls = loss_cls.sum() / num_fgs
 
         # ------------------ Regression loss ------------------
         box_preds_pos = box_preds.view(-1, 4)[fg_masks]
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
+        box_targets_pos = gt_bbox_targets.view(-1, 4)[fg_masks]
+        bbox_weight = gt_score_targets[fg_masks].sum(-1)
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
         loss_box = loss_box.sum() / num_fgs
 
         # ------------------ Distribution focal loss  ------------------
         ## process anchors
-        anchors = torch.cat(anchors, dim=0)
+        anchors = torch.cat(outputs['anchors'], dim=0)
         anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
         ## process stride tensors
         strides = torch.cat(outputs['stride_tensor'], dim=0)
@@ -166,40 +166,131 @@ class Criterion(object):
         anchors_pos = anchors[fg_masks]
         strides_pos = strides[fg_masks]
         ## compute dfl
-        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos)
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
         loss_dfl = loss_dfl.sum() / num_fgs
 
         # total loss
-        losses = self.loss_cls_weight * loss_cls + \
-                 self.loss_box_weight * loss_box + \
-                 self.loss_dfl_weight * loss_dfl
-
-        loss_dict = dict(
-                loss_cls = loss_cls,
-                loss_box = loss_box,
-                loss_dfl = loss_dfl,
-                losses = losses
-        )
+        if not self.use_dfl:
+            losses = loss_cls * self.loss_cls_weight + loss_box * self.loss_box_weight
+            loss_dict = dict(
+                    loss_cls = loss_cls,
+                    loss_box = loss_box,
+                    losses = losses
+            )
+        else:
+            losses = loss_cls * self.loss_cls_weight + loss_box * self.loss_box_weight + loss_dfl * self.loss_dfl_weight
+            loss_dict = dict(
+                    loss_cls = loss_cls,
+                    loss_box = loss_box,
+                    loss_dfl = loss_dfl,
+                    losses = losses
+            )
+
+        return loss_dict
+    
 
-        # ------------------ Aux regression loss ------------------
-        if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.loss_box_aux:
-            ## delta_preds
-            delta_preds = torch.cat(outputs['pred_delta'], dim=1)
-            delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
-            ## aux loss
-            loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets, anchors_pos, strides_pos)
-            loss_box_aux = loss_box_aux.sum() / num_fgs
+class ClassificationLoss(nn.Module):
+    def __init__(self, cfg, reduction='none'):
+        super(ClassificationLoss, self).__init__()
+        self.cfg = cfg
+        self.reduction = reduction
+        # For VFL
+        self.alpha = 0.75
+        self.gamma = 2.0
 
-            losses += loss_box_aux
-            loss_dict['loss_box_aux'] = loss_box_aux
 
+    def binary_cross_entropy(self, pred_logits, gt_score):
+        loss = F.binary_cross_entropy_with_logits(
+            pred_logits.float(), gt_score.float(), reduction='none')
 
-        return loss_dict
+        if self.reduction == 'sum':
+            loss = loss.sum()
+        elif self.reduction == 'mean':
+            loss = loss.mean()
+
+        return loss
+
+
+    def forward(self, pred_logits, gt_score):
+        if self.cfg['cls_loss'] == 'bce':
+            return self.binary_cross_entropy(pred_logits, gt_score)
+
+
+class RegressionLoss(nn.Module):
+    def __init__(self, num_classes, reg_max, use_dfl):
+        super(RegressionLoss, self).__init__()
+        self.num_classes = num_classes
+        self.reg_max = reg_max
+        self.use_dfl = use_dfl
+
+
+    def df_loss(self, pred_regs, target):
+        gt_left = target.to(torch.long)
+        gt_right = gt_left + 1
+        weight_left = gt_right.to(torch.float) - target
+        weight_right = 1 - weight_left
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_regs.view(-1, self.reg_max + 1),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_regs.view(-1, self.reg_max + 1),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss = (loss_left + loss_right).mean(-1, keepdim=True)
         
+        return loss
+
+
+    def forward(self, pred_regs, pred_boxs, anchors, gt_boxs, bbox_weight, fg_masks, strides):
+        """
+        Input:
+            pred_regs: (Tensor) [BM, 4*(reg_max + 1)]
+            pred_boxs: (Tensor) [BM, 4]
+            anchors: (Tensor) [BM, 2]
+            gt_boxs: (Tensor) [BM, 4]
+            bbox_weight: (Tensor) [BM, 1]
+            fg_masks: (Tensor) [BM,]
+            strides: (Tensor) [BM, 1]
+        """
+        # select positive samples mask
+        num_pos = fg_masks.sum()
+
+        if num_pos > 0:
+            pred_boxs_pos = pred_boxs[fg_masks]
+            gt_boxs_pos = gt_boxs[fg_masks]
+
+            # iou loss
+            ious = bbox_iou(pred_boxs_pos,
+                            gt_boxs_pos,
+                            xywh=False,
+                            CIoU=True)
+            loss_iou = (1.0 - ious) * bbox_weight
+               
+            # dfl loss
+            if self.use_dfl:
+                pred_regs_pos = pred_regs[fg_masks]
+                gt_boxs_s = gt_boxs / strides
+                anchors_s = anchors / strides
+                gt_ltrb_s = bbox2dist(anchors_s, gt_boxs_s, self.reg_max)
+                gt_ltrb_s_pos = gt_ltrb_s[fg_masks]
+                loss_dfl = self.df_loss(pred_regs_pos, gt_ltrb_s_pos)
+                loss_dfl *= bbox_weight
+            else:
+                loss_dfl = pred_regs.sum() * 0.
+
+        else:
+            loss_iou = pred_regs.sum() * 0.
+            loss_dfl = pred_regs.sum() * 0.
+
+        return loss_iou, loss_dfl
+
 
-def build_criterion(args, cfg, device, num_classes):
+def build_criterion(cfg, device, num_classes):
     criterion = Criterion(
-        args=args,
         cfg=cfg,
         device=device,
         num_classes=num_classes

+ 169 - 173
models/detectors/rtcdet/matcher.py

@@ -1,180 +1,176 @@
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from utils.box_ops import *
-
-
-# -------------------------- YOLOX's SimOTA Assigner --------------------------
-## Simple OTA
-class SimOTA(object):
-    """
-        This code referenced to https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
-    """
-    def __init__(self, num_classes, center_sampling_radius, topk_candidate ):
+from utils.box_ops import bbox_iou
+
+
+# -------------------------- Task Aligned Assigner --------------------------
+class TaskAlignedAssigner(nn.Module):
+    def __init__(self,
+                 num_classes     = 80,
+                 topk_candidates = 10,
+                 alpha           = 0.5,
+                 beta            = 6.0, 
+                 eps             = 1e-9):
+        super(TaskAlignedAssigner, self).__init__()
+        self.topk_candidates = topk_candidates
         self.num_classes = num_classes
-        self.center_sampling_radius = center_sampling_radius
-        self.topk_candidate = topk_candidate
-
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
 
     @torch.no_grad()
-    def __call__(self, 
-                 fpn_strides, 
-                 anchors, 
-                 pred_cls, 
-                 pred_box, 
-                 tgt_labels,
-                 tgt_bboxes):
-        # [M,]
-        strides_tensor = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
-                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
-        # List[F, M, 2] -> [M, 2]
-        anchors = torch.cat(anchors, dim=0)
-        num_anchor = anchors.shape[0]        
-        num_gt = len(tgt_labels)
-
-        # ----------------------- Find inside points -----------------------
-        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
-            tgt_bboxes, anchors, strides_tensor, num_anchor, num_gt)
-        cls_preds = pred_cls[fg_mask].float()   # [Mp, C]
-        box_preds = pred_box[fg_mask].float()   # [Mp, 4]
-
-        # ----------------------- Reg cost -----------------------
-        pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds)      # [N, Mp]
-        reg_cost = -torch.log(pair_wise_ious + 1e-8)            # [N, Mp]
-
-        # ----------------------- Cls cost -----------------------
-        with torch.cuda.amp.autocast(enabled=False):
-            # [Mp, C] -> [N, Mp, C]
-            cls_preds_expand = cls_preds.unsqueeze(0).repeat(num_gt, 1, 1)
-            # prepare cls_target
-            cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
-            cls_targets = cls_targets.unsqueeze(1).repeat(1, cls_preds_expand.size(1), 1)
-            cls_targets *= pair_wise_ious.unsqueeze(-1)  # iou-aware
-            # [N, Mp]
-            cls_cost = F.binary_cross_entropy_with_logits(cls_preds_expand, cls_targets, reduction="none").sum(-1)
-        del cls_preds_expand
-
-        #----------------------- Dynamic K-Matching -----------------------
-        cost_matrix = (
-            cls_cost
-            + 3.0 * reg_cost
-            + 100000.0 * (~is_in_boxes_and_center)
-        ) # [N, Mp]
-
-        (
-            assigned_labels,         # [num_fg,]
-            assigned_ious,           # [num_fg,]
-            assigned_indexs,         # [num_fg,]
-        ) = self.dynamic_k_matching(
-            cost_matrix,
-            pair_wise_ious,
-            tgt_labels,
-            num_gt,
-            fg_mask
-            )
-        del cls_cost, cost_matrix, pair_wise_ious, reg_cost
-
-        return fg_mask, assigned_labels, assigned_ious, assigned_indexs
-
-
-    def get_in_boxes_info(
-        self,
-        gt_bboxes,   # [N, 4]
-        anchors,     # [M, 2]
-        strides,     # [M,]
-        num_anchors, # M
-        num_gt,      # N
-        ):
-        # anchor center
-        x_centers = anchors[:, 0]
-        y_centers = anchors[:, 1]
-
-        # [M,] -> [1, M] -> [N, M]
-        x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
-        y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
-
-        # [N,] -> [N, 1] -> [N, M]
-        gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
-        gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
-        gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
-        gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
-
-        b_l = x_centers - gt_bboxes_l
-        b_r = gt_bboxes_r - x_centers
-        b_t = y_centers - gt_bboxes_t
-        b_b = gt_bboxes_b - y_centers
-        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
-
-        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
-        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
-        # in fixed center
-        center_radius = self.center_sampling_radius
-
-        # [N, 2]
-        gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
+    def forward(self,
+                pd_scores,
+                pd_bboxes,
+                anc_points,
+                gt_labels,
+                gt_bboxes):
+        self.bs = pd_scores.size(0)
+        self.n_max_boxes = gt_bboxes.size(1)
+
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
+
+        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+            mask_pos, overlaps, self.n_max_boxes)
+
+        # assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(
+            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
+        # get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
+        # get in_gts mask, (b, max_num_obj, h*w)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get topk_metric mask, (b, max_num_obj, h*w)
+        mask_topk = self.select_topk_candidates(align_metric * mask_in_gts)
+        # merge all mask to a final mask, (b, max_num_obj, h*w)
+        mask_pos = mask_topk * mask_in_gts
+
+        return mask_pos, align_metric, overlaps
+
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
+        # get the scores of each grid for each gt cls
+        bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
+
+        overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False,
+                            CIoU=True).squeeze(3).clamp(0)
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+
+        return align_metric, overlaps
+
+    def select_topk_candidates(self, metrics, largest=True):
+        """
+        Args:
+            metrics: (b, max_num_obj, h*w).
+            topk_mask: (b, max_num_obj, topk) or None
+        """
+        num_anchors = metrics.shape[-1]  # h*w
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk_candidates, dim=-1, largest=largest)
+        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).tile([1, 1, self.topk_candidates])
+        # (b, max_num_obj, topk)
+        topk_idxs[~topk_mask] = 0
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
+        # filter invalid bboxes
+        is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
         
-        # [1, M]
-        center_radius_ = center_radius * strides.unsqueeze(0)
-
-        gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
-        gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
-        gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
-        gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
-
-        c_l = x_centers - gt_bboxes_l
-        c_r = gt_bboxes_r - x_centers
-        c_t = y_centers - gt_bboxes_t
-        c_b = gt_bboxes_b - y_centers
-        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
-        is_in_centers = center_deltas.min(dim=-1).values > 0.0
-        is_in_centers_all = is_in_centers.sum(dim=0) > 0
-
-        # in boxes and in centers
-        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
-
-        is_in_boxes_and_center = (
-            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
-        )
-        return is_in_boxes_anchor, is_in_boxes_and_center
-    
+        return is_in_topk.to(metrics.dtype)
+
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        # assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+
+        # assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
+        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+
+        # assigned target scores
+        target_labels.clamp(0)
+        target_scores = F.one_hot(target_labels, self.num_classes)  # (b, h*w, 80)
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+        return target_labels, target_bboxes, target_scores
     
-    def dynamic_k_matching(
-        self, 
-        cost, 
-        pair_wise_ious, 
-        gt_classes, 
-        num_gt, 
-        fg_mask
-        ):
-        # Dynamic K
-        # ---------------------------------------------------------------
-        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
-
-        ious_in_boxes_matrix = pair_wise_ious
-        n_candidate_k = min(self.topk_candidate, ious_in_boxes_matrix.size(1))
-        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
-        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-        dynamic_ks = dynamic_ks.tolist()
-        for gt_idx in range(num_gt):
-            _, pos_idx = torch.topk(
-                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
-            )
-            matching_matrix[gt_idx][pos_idx] = 1
-
-        del topk_ious, dynamic_ks, pos_idx
-
-        anchor_matching_gt = matching_matrix.sum(0)
-        if (anchor_matching_gt > 1).sum() > 0:
-            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
-            matching_matrix[:, anchor_matching_gt > 1] *= 0
-            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
-        fg_mask_inboxes = matching_matrix.sum(0) > 0
-
-        fg_mask[fg_mask.clone()] = fg_mask_inboxes
-
-        assigned_indexs = matching_matrix[:, fg_mask_inboxes].argmax(0)
-        assigned_labels = gt_classes[assigned_indexs]
-
-        assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[
-            fg_mask_inboxes
-        ]
-        return assigned_labels, assigned_ious, assigned_indexs
+
+# -------------------------- Basic Functions --------------------------
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+    """select the positive anchors's center in gt
+    Args:
+        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
+        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    n_anchors = xy_centers.size(0)
+    bs, n_max_boxes, _ = gt_bboxes.size()
+    _gt_bboxes = gt_bboxes.reshape([-1, 4])
+    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+    b_lt = xy_centers - gt_bboxes_lt
+    b_rb = gt_bboxes_rb - xy_centers
+    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+    """if an anchor box is assigned to multiple gts,
+        the one with the highest iou will be selected.
+    Args:
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    Return:
+        target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        fg_mask (Tensor): shape(bs, num_total_anchors)
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    fg_mask = mask_pos.sum(axis=-2)
+    if fg_mask.max() > 1:
+        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1])
+        max_overlaps_idx = overlaps.argmax(axis=1)
+        is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes)
+        is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)
+        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos)
+        fg_mask = mask_pos.sum(axis=-2)
+    target_gt_idx = mask_pos.argmax(axis=-2)
+    return target_gt_idx, fg_mask , mask_pos
+
+def iou_calculator(box1, box2, eps=1e-9):
+    """Calculate iou for batch
+    Args:
+        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
+        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
+    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
+    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
+    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
+    x1y1 = torch.maximum(px1y1, gx1y1)
+    x2y2 = torch.minimum(px2y2, gx2y2)
+    overlap = (x2y2 - x1y1).clip(0).prod(-1)
+    area1 = (px2y2 - px1y1).clip(0).prod(-1)
+    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
+    union = area1 + area2 - overlap + eps
+
+    return overlap / union

+ 11 - 7
models/detectors/rtcdet/rtcdet.py

@@ -29,8 +29,9 @@ class RTCDet(nn.Module):
         # ---------------------- Basic Parameters ----------------------
         self.cfg = cfg
         self.device = device
-        self.stride = cfg['stride']
+        self.strides = cfg['stride']
         self.reg_max = cfg['reg_max']
+        self.num_levels = len(self.strides)
         self.num_classes = num_classes
         self.trainable = trainable
         self.conf_thresh = conf_thresh
@@ -53,14 +54,16 @@ class RTCDet(nn.Module):
         self.fpn_dims = self.fpn.out_dim
 
         ## ----------- Heads -----------
-        self.det_heads = build_det_head(
-            cfg, self.fpn_dims, self.head_dim, num_classes, num_levels=len(self.stride))
+        self.det_heads = build_det_head(cfg, self.fpn_dims, self.num_levels, num_classes, self.reg_max)
 
         ## ----------- Preds -----------
-        self.pred_layers = build_pred_layer(
-            self.det_heads.cls_head_dim, self.det_heads.reg_head_dim, self.stride,
-            num_classes=num_classes, num_coords=4, num_levels=len(self.stride), reg_max=self.reg_max)
-
+        self.pred_layers = build_pred_layer(cls_dim     = self.det_heads.cls_head_dim,
+                                            reg_dim     = self.det_heads.reg_head_dim,
+                                            strides     = self.strides,
+                                            num_classes = num_classes,
+                                            num_coords  = 4,
+                                            num_levels  = self.num_levels,
+                                            reg_max     = self.reg_max)
 
     ## post-process
     def post_process(self, cls_preds, box_preds):
@@ -155,6 +158,7 @@ class RTCDet(nn.Module):
             return bboxes, scores, labels
 
 
+    # ---------------------- Main Process for Training ----------------------
     def forward(self, x):
         if not self.trainable:
             return self.inference_single_image(x)

+ 19 - 9
models/detectors/rtcdet/rtcdet_head.py

@@ -59,6 +59,15 @@ class SingleLevelHead(nn.Module):
         self.cls_feats = nn.Sequential(*cls_feats)
         self.reg_feats = nn.Sequential(*reg_feats)
 
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
 
     def forward(self, x):
         """
@@ -72,14 +81,14 @@ class SingleLevelHead(nn.Module):
 
 # Multi-level Head
 class MultiLevelHead(nn.Module):
-    def __init__(self, cfg, in_dims, out_dim, num_classes=80, num_levels=3):
+    def __init__(self, cfg, in_dims, num_levels=3, num_classes=80, reg_max=16):
         super().__init__()
         ## ----------- Network Parameters -----------
         self.multi_level_heads = nn.ModuleList(
             [SingleLevelHead(
                 in_dims[level],
-                out_dim,            # cls head dim
-                out_dim,            # reg head dim
+                max(in_dims[0], min(num_classes, 100)), # cls head out_dim
+                max(in_dims[0]//4, 16, 4*reg_max),      # reg head out_dim
                 cfg['num_cls_head'],
                 cfg['num_reg_head'],
                 cfg['head_act'],
@@ -89,7 +98,6 @@ class MultiLevelHead(nn.Module):
             ])
         # --------- Basic Parameters ----------
         self.in_dims = in_dims
-        self.num_classes = num_classes
 
         self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
         self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
@@ -112,9 +120,9 @@ class MultiLevelHead(nn.Module):
     
 
 # build detection head
-def build_det_head(cfg, in_dim, out_dim, num_classes=80, num_levels=3):
+def build_det_head(cfg, in_dims, num_levels=3, num_classes=80, reg_max=16):
     if cfg['head'] == 'decoupled_head':
-        head = MultiLevelHead(cfg, in_dim, out_dim, num_classes, num_levels) 
+        head = MultiLevelHead(cfg, in_dims, num_levels, num_classes, reg_max)
 
     return head
 
@@ -131,10 +139,12 @@ if __name__ == '__main__':
         'head_depthwise': False,
         'reg_max': 16,
     }
-    fpn_dims = [256, 256, 256]
-    out_dim = 256
+    fpn_dims = [256, 512, 512]
+    cls_out_dim = 256
+    reg_out_dim = 64
     # Head-1
-    model = build_det_head(cfg, fpn_dims, out_dim, num_classes=80, reg_max=16, num_levels=3)
+    model = build_det_head(cfg, fpn_dims, num_levels=3, num_classes=80, reg_max=16)
+    print(model)
     fpn_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
     t0 = time.time()
     outputs = model(fpn_feats)

+ 10 - 0
models/detectors/rtcdet/rtcdet_pafpn.py

@@ -45,6 +45,16 @@ class RTCDetPaFPN(nn.Module):
             self.out_layers = None
             self.out_dim = [round(256*cfg['width']), round(512*cfg['width']), round(1024*cfg['width'])]
 
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
 
     def forward(self, fpn_feats):
         c3, c4, c5 = fpn_feats

+ 12 - 19
models/detectors/rtcdet/rtcdet_pred.py

@@ -1,3 +1,4 @@
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -5,9 +6,10 @@ import torch.nn.functional as F
 
 # Single-level pred layer
 class SingleLevelPredLayer(nn.Module):
-    def __init__(self, cls_dim, reg_dim, num_classes, num_coords=4):
+    def __init__(self, cls_dim, reg_dim, stride, num_classes, num_coords=4):
         super().__init__()
         # --------- Basic Parameters ----------
+        self.stride = stride
         self.cls_dim = cls_dim
         self.reg_dim = reg_dim
         self.num_classes = num_classes
@@ -19,23 +21,15 @@ class SingleLevelPredLayer(nn.Module):
 
         self.init_bias()
         
-
     def init_bias(self):
-        # Init bias
-        init_prob = 0.01
-        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-        # cls pred
+        # cls pred bias
         b = self.cls_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
+        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
         self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # reg pred
+        # reg pred bias
         b = self.reg_pred.bias.view(-1, )
         b.data.fill_(1.0)
         self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        w = self.reg_pred.weight
-        w.data.fill_(0.)
-        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
 
     def forward(self, cls_feat, reg_feat):
         """
@@ -66,15 +60,15 @@ class MultiLevelPredLayer(nn.Module):
             [SingleLevelPredLayer(
                 cls_dim,
                 reg_dim,
+                strides[l],
                 num_classes,
                 num_coords * self.reg_max)
-                for _ in range(num_levels)
+                for l in range(num_levels)
             ])
         ## proj conv
-        self.proj = nn.Parameter(torch.linspace(0, reg_max, reg_max), requires_grad=False)
-        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
-        self.proj_conv.weight = nn.Parameter(self.proj.view([1, reg_max, 1, 1]).clone().detach(), requires_grad=False)
-
+        proj_init = torch.arange(reg_max, dtype=torch.float)
+        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
+        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, reg_max, 1, 1]))
 
     def generate_anchors(self, level, fmp_size):
         """
@@ -90,7 +84,6 @@ class MultiLevelPredLayer(nn.Module):
 
         return anchors
         
-
     def forward(self, cls_feats, reg_feats):
         all_anchors = []
         all_strides = []
@@ -116,7 +109,7 @@ class MultiLevelPredLayer(nn.Module):
 
             # ----------------------- Decode bbox -----------------------
             B, M = reg_pred.shape[:2]
-            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
             delta_pred = reg_pred.reshape([B, M, 4, self.reg_max])
             # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
             delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()