View source code

modify yolov6

yjh0410 11 months ago
parent
commit
870ddf85ec

+ 10 - 22
yolo/config/yolov6_config.py

@@ -23,25 +23,13 @@ class Yolov6BaseConfig(object):
         self.out_stride = [8, 16, 32]
         self.max_stride = 32
         self.num_levels = 3
-        self.model_scale = "b"
+        self.model_scale = "l"
         ## Backbone
         self.use_pretrained = True
-        ## Neck
-        self.neck_act       = 'silu'
-        self.neck_norm      = 'BN'
-        self.neck_depthwise = False
-        self.neck_expand_ratio = 0.5
-        self.spp_pooling_size  = 5
-        ## FPN
-        self.fpn_act  = 'silu'
-        self.fpn_norm = 'BN'
-        self.fpn_depthwise = False
         ## Head
-        self.head_act  = 'silu'
-        self.head_norm = 'BN'
-        self.head_depthwise = False
-        self.num_cls_head   = 1
-        self.num_reg_head   = 1
+        self.head_dim       = 256
+        self.num_cls_head   = 2
+        self.num_reg_head   = 2
 
         # ---------------- Post-process config ----------------
         ## Post process
@@ -49,17 +37,17 @@ class Yolov6BaseConfig(object):
         self.val_conf_thresh = 0.001
         self.val_nms_thresh  = 0.7
         self.test_topk = 100
-        self.test_conf_thresh = 0.2
+        self.test_conf_thresh = 0.4
         self.test_nms_thresh  = 0.5
 
         # ---------------- Assignment config ----------------
         ## Matcher
-        self.tal_topk_candidates = 13
-        self.tal_alpha = 1.0
-        self.tal_beta  = 6.0
+        self.ota_center_sampling_radius = 2.5
+        self.ota_topk_candidate = 10
         ## Loss weight
+        self.loss_obj = 1.0
         self.loss_cls = 1.0
-        self.loss_box = 2.5
+        self.loss_box = 5.0
 
         # ---------------- ModelEMA config ----------------
         self.use_ema = True
@@ -88,7 +76,7 @@ class Yolov6BaseConfig(object):
         # ---------------- Data process config ----------------
         self.aug_type = 'yolo'
         self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.15
+        self.mixup_prob  = 0.1
         self.copy_paste  = 0.0           # approximated by YOLOX's mixup
         self.multi_scale = [0.5, 1.25]   # multi scale: [img_size * 0.5, img_size * 1.25]
         ## Pixel mean & std

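The config swaps the TAL assigner knobs for SimOTA-style ones (ota_center_sampling_radius, ota_topk_candidate) and adds an objectness loss weight alongside the rebalanced box loss. As a sanity check on the multi_scale comment above, here is a minimal sketch of how the training size range could resolve; the base img_size of 640 and the snap-to-stride rule are assumptions, not part of this commit:

# Hedged sketch: resolving multi_scale = [0.5, 1.25] into pixel sizes.
# img_size = 640 is an assumed value, not taken from this commit.
img_size    = 640
multi_scale = [0.5, 1.25]
max_stride  = 32

# Snap both ends of the range to multiples of the maximum stride.
min_size = round(img_size * multi_scale[0] / max_stride) * max_stride  # 320
max_size = round(img_size * multi_scale[1] / max_stride) * max_stride  # 800
print(min_size, max_size)  # 320 800
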
+ 79 - 89
yolo/models/yolov6/loss.py

@@ -1,151 +1,141 @@
 import torch
 import torch.nn.functional as F
-
-from utils.box_ops import bbox_iou
+from utils.box_ops import get_ious
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
-from .matcher import TaskAlignedAssigner
+from .matcher import YoloxMatcher
 
 
 class SetCriterion(object):
     def __init__(self, cfg):
-        # --------------- Basic parameters ---------------
         self.cfg = cfg
-        self.reg_max = cfg.reg_max
         self.num_classes = cfg.num_classes
-        # --------------- Loss config ---------------
+        self.loss_obj_weight = cfg.loss_obj
         self.loss_cls_weight = cfg.loss_cls
         self.loss_box_weight = cfg.loss_box
-        # --------------- Matcher config ---------------
-        self.matcher = TaskAlignedAssigner(num_classes     = cfg.num_classes,
-                                           topk_candidates = cfg.tal_topk_candidates,
-                                           alpha           = cfg.tal_alpha,
-                                           beta            = cfg.tal_beta
-                                           )
-
-    def loss_classes(self, pred_cls, labels, scores):
-        # compute bce loss
-        alpha = 0.75
-        gamma = 2.0
-        # pred and target should be of the same size
-        bg_class_ind = pred_cls.shape[-1]
-        pos_inds = ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)
-
-        new_scores = pred_cls.new_zeros(pred_cls.shape)
-        pos_labels = labels[pos_inds]
-        new_scores[pos_inds, pos_labels] = scores[pos_inds].clone().detach()
-
-        pred_sigmoid = pred_cls.sigmoid()
-        focal_weight = new_scores * (new_scores > 0.0).float() + \
-            alpha * (pred_sigmoid - new_scores).abs().pow(gamma) * \
-            (new_scores <= 0.0).float()
-        
-        loss_cls = F.binary_cross_entropy_with_logits(
-            pred_cls, new_scores, reduction='none') * focal_weight
+        # matcher
+        self.matcher = YoloxMatcher(cfg.num_classes, cfg.ota_center_sampling_radius, cfg.ota_topk_candidate)
+
+    def loss_objectness(self, pred_obj, gt_obj):
+        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
+
+        return loss_obj
     
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+
         return loss_cls
-    
-    def loss_bboxes(self, pred_box, gt_box, bbox_weight):
+
+    def loss_bboxes(self, pred_box, gt_box):
         # regression loss
-        ious = bbox_iou(pred_box, gt_box, xywh=False, GIoU=True)
-        loss_box = (1.0 - ious.squeeze(-1)) * bbox_weight
+        ious = get_ious(pred_box, gt_box, "xyxy", 'giou')
+        loss_box = 1.0 - ious
 
         return loss_box
-    
+
     def __call__(self, outputs, targets):        
         """
+            outputs['pred_obj']: List(Tensor) [B, M, 1]
             outputs['pred_cls']: List(Tensor) [B, M, C]
-            outputs['pred_reg']: List(Tensor) [B, M, 4*(reg_max+1)]
+            outputs['pred_reg']: List(Tensor) [B, M, 4]
             outputs['pred_box']: List(Tensor) [B, M, 4]
-            outputs['anchors']: List(Tensor) [M, 2]
             outputs['strides']: List(Int) [8, 16, 32] output stride
-            outputs['stride_tensor']: List(Tensor) [M, 1]
             targets: (List) [dict{'boxes': [...], 
                                  'labels': [...], 
                                  'orig_size': ...}, ...]
         """
+        bs = outputs['pred_cls'][0].shape[0]
+        device = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        anchors = outputs['anchors']
         # preds: [B, M, C]
+        obj_preds = torch.cat(outputs['pred_obj'], dim=1)
         cls_preds = torch.cat(outputs['pred_cls'], dim=1)
         box_preds = torch.cat(outputs['pred_box'], dim=1)
-        bs, num_anchors = cls_preds.shape[:2]
-        device = cls_preds.device
-        anchors = torch.cat(outputs['anchors'], dim=0)
-        
-        # --------------- label assignment ---------------
-        gt_label_targets = []
-        gt_score_targets = []
-        gt_bbox_targets = []
+
+        # label assignment
+        cls_targets = []
+        box_targets = []
+        obj_targets = []
         fg_masks = []
+
         for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
-            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
+            tgt_labels = targets[batch_idx]["labels"].to(device)
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
 
             # check target
-            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
+            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
+                num_anchors = sum([ab.shape[0] for ab in anchors])
                 # There is no valid gt
-                fg_mask  = cls_preds.new_zeros(1, num_anchors).bool()                       # [1, M,]
-                gt_label = cls_preds.new_zeros((1, num_anchors)).long()                     # [1, M,]
-                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)).float()  # [1, M, C]
-                gt_box   = cls_preds.new_zeros((1, num_anchors, 4)).float()                 # [1, M, 4]
+                cls_target = obj_preds.new_zeros((0, self.num_classes))
+                box_target = obj_preds.new_zeros((0, 4))
+                obj_target = obj_preds.new_zeros((num_anchors, 1))
+                fg_mask = obj_preds.new_zeros(num_anchors).bool()
             else:
-                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
-                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
                 (
-                    gt_label,   # 
-                    gt_box,     # [1, M, 4]
-                    gt_score,   # [1, M, C]
-                    fg_mask,    # [1, M,]
-                    _
+                    fg_mask,
+                    assigned_labels,
+                    assigned_ious,
+                    assigned_indexs
                 ) = self.matcher(
-                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
-                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
-                    anc_points = anchors,
-                    gt_labels = tgt_labels,
-                    gt_bboxes = tgt_boxs
+                    fpn_strides = fpn_strides,
+                    anchors = anchors,
+                    pred_obj = obj_preds[batch_idx],
+                    pred_cls = cls_preds[batch_idx], 
+                    pred_box = box_preds[batch_idx],
+                    tgt_labels = tgt_labels,
+                    tgt_bboxes = tgt_bboxes
                     )
-            gt_label_targets.append(gt_label)
-            gt_score_targets.append(gt_score)
-            gt_bbox_targets.append(gt_box)
+
+                obj_target = fg_mask.unsqueeze(-1)
+                cls_target = F.one_hot(assigned_labels.long(), self.num_classes)
+                cls_target = cls_target * assigned_ious.unsqueeze(-1)
+                box_target = tgt_bboxes[assigned_indexs]
+
+            cls_targets.append(cls_target)
+            box_targets.append(box_target)
+            obj_targets.append(obj_target)
             fg_masks.append(fg_mask)
 
-        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
-        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
-        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
-        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
-        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
-        num_fgs = gt_score_targets.sum()
+        cls_targets = torch.cat(cls_targets, 0)
+        box_targets = torch.cat(box_targets, 0)
+        obj_targets = torch.cat(obj_targets, 0)
+        fg_masks = torch.cat(fg_masks, 0)
+        num_fgs = fg_masks.sum()
 
-        # Average loss normalizer across all the GPUs
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_fgs)
         num_fgs = (num_fgs / get_world_size()).clamp(1.0)
 
+        # ------------------ Objectness loss ------------------
+        loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
+        loss_obj = loss_obj.sum() / num_fgs
+        
         # ------------------ Classification loss ------------------
-        target_labels = torch.where(fg_masks > 0, gt_label_targets,
-                                    torch.full_like(gt_label_targets, self.num_classes))
-        target_scores = gt_score_targets.new_zeros(gt_score_targets.shape[0])
-        target_scores[fg_masks] = gt_score_targets[fg_masks, target_labels[fg_masks]]
-        cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, target_labels, target_scores)
+        cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
+        loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
         loss_cls = loss_cls.sum() / num_fgs
 
         # ------------------ Regression loss ------------------
         box_preds_pos = box_preds.view(-1, 4)[fg_masks]
-        box_targets_pos = gt_bbox_targets.view(-1, 4)[fg_masks]
-        bbox_weight = gt_score_targets[fg_masks].sum(-1)
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
         loss_box = loss_box.sum() / num_fgs
 
         # total loss
-        losses = loss_cls * self.loss_cls_weight + loss_box * self.loss_box_weight
+        losses = self.loss_obj_weight * loss_obj + \
+                 self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box
+
+        # Loss dict
         loss_dict = dict(
+                loss_obj = loss_obj,
                 loss_cls = loss_cls,
                 loss_box = loss_box,
                 losses = losses
         )
 
         return loss_dict
-    
+
 
 if __name__ == "__main__":
-    pass
+    pass

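The rewritten criterion builds its classification target as a one-hot label scaled by the IoU of the matched prediction, so poorly localized positives receive softer targets. A minimal sketch of that target construction, with made-up shapes and values:

# Hedged sketch of the soft classification target used above.
import torch
import torch.nn.functional as F

num_classes     = 4
assigned_labels = torch.tensor([2, 0])      # classes of the matched GTs
assigned_ious   = torch.tensor([0.9, 0.6])  # IoU of each positive with its GT

cls_target = F.one_hot(assigned_labels, num_classes).float()
cls_target = cls_target * assigned_ious.unsqueeze(-1)
print(cls_target)
# tensor([[0.0000, 0.0000, 0.9000, 0.0000],
#         [0.6000, 0.0000, 0.0000, 0.0000]])
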
+ 176 - 172
yolo/models/yolov6/matcher.py

@@ -1,176 +1,180 @@
 import torch
-import torch.nn as nn
-from utils.box_ops import bbox_iou
-
-
-# ------------------ Task Aligned Assigner ------------------
-class TaskAlignedAssigner(nn.Module):
-    def __init__(self,
-                 num_classes     = 80,
-                 topk_candidates = 10,
-                 alpha           = 0.5,
-                 beta            = 6.0, 
-                 eps             = 1e-9,
-                 ):
-        super(TaskAlignedAssigner, self).__init__()
-        self.topk_candidates = topk_candidates
+import torch.nn.functional as F
+from utils.box_ops import *
+
+
+class YoloxMatcher(object):
+    """
+        This code is adapted from https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
+    """
+    def __init__(self, num_classes, center_sampling_radius, topk_candidate):
         self.num_classes = num_classes
-        self.bg_idx = num_classes
-        self.alpha = alpha
-        self.beta = beta
-        self.eps = eps
+        self.center_sampling_radius = center_sampling_radius
+        self.topk_candidate = topk_candidate
+
 
     @torch.no_grad()
-    def forward(self,
-                pd_scores,
-                pd_bboxes,
-                anc_points,
-                gt_labels,
-                gt_bboxes):
-        self.bs = pd_scores.size(0)
-        self.n_max_boxes = gt_bboxes.size(1)
-
-        mask_pos, align_metric, overlaps = self.get_pos_mask(
-            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
-
-        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(
-            mask_pos, overlaps, self.n_max_boxes)
-
-        # Assigned target
-        target_labels, target_bboxes, target_scores = self.get_targets(
-            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
-
-        # normalize
-        align_metric *= mask_pos
-        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
-        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
-        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
-        target_scores = target_scores * norm_align_metric
-
-        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
-
-    def select_highest_overlaps(self, mask_pos, overlaps, n_max_boxes):
-        """if an anchor box is assigned to multiple gts,
-            the one with the highest iou will be selected.
-        Args:
-            mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-            overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-        Return:
-            target_gt_idx (Tensor): shape(bs, num_total_anchors)
-            fg_mask (Tensor): shape(bs, num_total_anchors)
-            mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-        """
-        fg_mask = mask_pos.sum(-2)
-        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
-            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
-            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
-
-            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
-            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
-
-            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
-            fg_mask = mask_pos.sum(-2)
-        # Find each grid serve which gt(index)
-        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
-
-        return target_gt_idx, fg_mask, mask_pos
-
-    def select_candidates_in_gts(self, xy_centers, gt_bboxes, eps=1e-9):
-        """select the positive anchors's center in gt
-        Args:
-            xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
-            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
-        Return:
-            (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-        """
-        n_anchors = xy_centers.size(0)
-        bs, n_max_boxes, _ = gt_bboxes.size()
-        _gt_bboxes = gt_bboxes.reshape([-1, 4])
-        xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
-        gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
-        gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
-        b_lt = xy_centers - gt_bboxes_lt
-        b_rb = gt_bboxes_rb - xy_centers
-        bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
-        bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
-        return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
-
-    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
-        # get in_gts mask, (b, max_num_obj, h*w)
-        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
-        # get anchor_align metric, (b, max_num_obj, h*w)
-        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts)
-        # get topk_metric mask, (b, max_num_obj, h*w)
-        mask_topk = self.select_topk_candidates(align_metric)
-        # merge all mask to a final mask, (b, max_num_obj, h*w)
-        mask_pos = mask_topk * mask_in_gts
-
-        return mask_pos, align_metric, overlaps
-
-    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts):
-        """Compute alignment metric given predicted and ground truth bounding boxes."""
-        na = pd_bboxes.shape[-2]
-        mask_in_gts = mask_in_gts.bool()  # b, max_num_obj, h*w
-        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
-        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
-
-        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
-        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
-        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
-        # Get the scores of each grid for each gt cls
-        bbox_scores[mask_in_gts] = pd_scores[ind[0], :, ind[1]][mask_in_gts]  # b, max_num_obj, h*w
-
-        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
-        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_in_gts]
-        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_in_gts]
-        overlaps[mask_in_gts] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
-
-        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
-        return align_metric, overlaps
-
-    def select_topk_candidates(self, metrics, largest=True):
-        """
-        Args:
-            metrics: (b, max_num_obj, h*w).
-            topk_mask: (b, max_num_obj, topk) or None
-        """
-        # (b, max_num_obj, topk)
-        topk_metrics, topk_idxs = torch.topk(metrics, self.topk_candidates, dim=-1, largest=largest)
-        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
-        # (b, max_num_obj, topk)
-        topk_idxs.masked_fill_(~topk_mask, 0)
-
-        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
-        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
-        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
-        for k in range(self.topk_candidates):
-            # Expand topk_idxs for each value of k and add 1 at the specified positions
-            count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
-        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
-        # Filter invalid bboxes
-        count_tensor.masked_fill_(count_tensor > 1, 0)
-
-        return count_tensor.to(metrics.dtype)
-
-    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
-        # Assigned target labels, (b, 1)
-        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
-        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
-        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
-
-        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
-        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
-
-        # Assigned target scores
-        target_labels.clamp_(0)
-
-        # 10x faster than F.one_hot()
-        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
-                                    dtype=torch.int64,
-                                    device=target_labels.device)  # (b, h*w, 80)
-        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
-
-        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
-        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
-
-        return target_labels, target_bboxes, target_scores
+    def __call__(self, 
+                 fpn_strides, 
+                 anchors, 
+                 pred_obj, 
+                 pred_cls, 
+                 pred_box, 
+                 tgt_labels,
+                 tgt_bboxes):
+        # [M,]
+        strides_tensor = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
+        # List[F, M, 2] -> [M, 2]
+        anchors = torch.cat(anchors, dim=0)
+        num_anchor = anchors.shape[0]        
+        num_gt = len(tgt_labels)
+
+        # ----------------------- Find inside points -----------------------
+        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+            tgt_bboxes, anchors, strides_tensor, num_anchor, num_gt)
+        obj_preds = pred_obj[fg_mask].float()   # [Mp, 1]
+        cls_preds = pred_cls[fg_mask].float()   # [Mp, C]
+        box_preds = pred_box[fg_mask].float()   # [Mp, 4]
+
+        # ----------------------- Reg cost -----------------------
+        pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds)      # [N, Mp]
+        reg_cost = -torch.log(pair_wise_ious + 1e-8)            # [N, Mp]
+
+        # ----------------------- Cls cost -----------------------
+        with torch.cuda.amp.autocast(enabled=False):
+            # [Mp, C]
+            score_preds = torch.sqrt(obj_preds.sigmoid_() * cls_preds.sigmoid_())
+            # [N, Mp, C]
+            score_preds = score_preds.unsqueeze(0).repeat(num_gt, 1, 1)
+            # prepare cls_target
+            cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
+            cls_targets = cls_targets.unsqueeze(1).repeat(1, score_preds.size(1), 1)
+            # [N, Mp]
+            cls_cost = F.binary_cross_entropy(score_preds, cls_targets, reduction="none").sum(-1)
+        del score_preds
+
+        #----------------------- Dynamic K-Matching -----------------------
+        cost_matrix = (
+            cls_cost
+            + 3.0 * reg_cost
+            + 100000.0 * (~is_in_boxes_and_center)
+        ) # [N, Mp]
+
+        (
+            assigned_labels,         # [num_fg,]
+            assigned_ious,           # [num_fg,]
+            assigned_indexs,         # [num_fg,]
+        ) = self.dynamic_k_matching(
+            cost_matrix,
+            pair_wise_ious,
+            tgt_labels,
+            num_gt,
+            fg_mask
+            )
+        del cls_cost, cost_matrix, pair_wise_ious, reg_cost
+
+        return fg_mask, assigned_labels, assigned_ious, assigned_indexs
+
+    def get_in_boxes_info(
+        self,
+        gt_bboxes,   # [N, 4]
+        anchors,     # [M, 2]
+        strides,     # [M,]
+        num_anchors, # M
+        num_gt,      # N
+        ):
+        # anchor center
+        x_centers = anchors[:, 0]
+        y_centers = anchors[:, 1]
+
+        # [M,] -> [1, M] -> [N, M]
+        x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
+        y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
+
+        # [N,] -> [N, 1] -> [N, M]
+        gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
+        gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
+        gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
+        gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
+
+        b_l = x_centers - gt_bboxes_l
+        b_r = gt_bboxes_r - x_centers
+        b_t = y_centers - gt_bboxes_t
+        b_b = gt_bboxes_b - y_centers
+        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
+
+        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
+        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
+        # in fixed center
+        center_radius = self.center_sampling_radius
+
+        # [N, 2]
+        gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
+        
+        # [1, M]
+        center_radius_ = center_radius * strides.unsqueeze(0)
+
+        gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
+        gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
+        gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
+        gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
+
+        c_l = x_centers - gt_bboxes_l
+        c_r = gt_bboxes_r - x_centers
+        c_t = y_centers - gt_bboxes_t
+        c_b = gt_bboxes_b - y_centers
+        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
+        is_in_centers = center_deltas.min(dim=-1).values > 0.0
+        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+
+        # in boxes and in centers
+        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
+
+        is_in_boxes_and_center = (
+            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
+        )
+        return is_in_boxes_anchor, is_in_boxes_and_center
+
+    def dynamic_k_matching(
+        self, 
+        cost, 
+        pair_wise_ious, 
+        gt_classes, 
+        num_gt, 
+        fg_mask
+        ):
+        # Dynamic K
+        # ---------------------------------------------------------------
+        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+
+        ious_in_boxes_matrix = pair_wise_ious
+        n_candidate_k = min(self.topk_candidate, ious_in_boxes_matrix.size(1))
+        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+        dynamic_ks = dynamic_ks.tolist()
+        for gt_idx in range(num_gt):
+            _, pos_idx = torch.topk(
+                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
+            )
+            matching_matrix[gt_idx][pos_idx] = 1
+
+        del topk_ious, dynamic_ks, pos_idx
+
+        anchor_matching_gt = matching_matrix.sum(0)
+        if (anchor_matching_gt > 1).sum() > 0:
+            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
+            matching_matrix[:, anchor_matching_gt > 1] *= 0
+            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+
+        fg_mask[fg_mask.clone()] = fg_mask_inboxes
+
+        assigned_indexs = matching_matrix[:, fg_mask_inboxes].argmax(0)
+        assigned_labels = gt_classes[assigned_indexs]
+
+        assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[
+            fg_mask_inboxes
+        ]
+        return assigned_labels, assigned_ious, assigned_indexs
+    

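The heart of SimOTA is the dynamic-k rule in dynamic_k_matching: each ground truth receives k positive anchors, where k is the clamped integer sum of its top candidate IoUs, so well-covered objects get more positives. A minimal sketch with illustrative IoUs:

# Hedged sketch of the dynamic-k rule, with made-up numbers.
import torch

pair_wise_ious = torch.tensor([[0.8, 0.7, 0.3, 0.1],   # GT 0 vs 4 candidates
                               [0.2, 0.1, 0.1, 0.0]])  # GT 1 vs 4 candidates
topk_candidate = 3

n_candidate_k = min(topk_candidate, pair_wise_ious.size(1))
topk_ious, _  = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
dynamic_ks    = torch.clamp(topk_ious.sum(1).int(), min=1)
print(dynamic_ks.tolist())  # [1, 1]: 1.8 -> 1 for GT 0; 0.4 -> 0 -> clamped to 1 for GT 1
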
+ 4 - 36
yolo/models/yolov6/modules.py

@@ -4,11 +4,6 @@ import torch.nn as nn
 
 
 # --------------------- Basic modules ---------------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
 def get_activation(act_type=None):
     if act_type == 'relu':
         return nn.ReLU(inplace=True)
@@ -23,49 +18,22 @@ def get_activation(act_type=None):
     else:
         raise NotImplementedError
 
-def get_norm(norm_type, dim):
-    if   norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
 class ConvModule(nn.Module):
     def __init__(self, 
                  in_dim,                   # in channels
                  out_dim,                  # out channels 
-                 kernel_size=1,            # kernel size 
+                 kernel_size=1,            # kernel size
                  padding=0,                # padding
                 stride=1,                 # stride
-                 dilation=1,               # dilation
                  act_type  :str = 'lrelu', # activation
-                 norm_type :str = 'BN',    # normalization
-                 depthwise :bool = False
                 ):
         super(ConvModule, self).__init__()
-        self.depthwise = depthwise
-        if not depthwise:
-            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1, bias=True)
-            self.norm = get_norm(norm_type, out_dim)
-        else:
-            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=True)
-            self.norm1 = get_norm(norm_type, in_dim)
-            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1, bias=True)
-            self.norm2 = get_norm(norm_type, out_dim)
+        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)
+        self.norm = nn.BatchNorm2d(out_dim)
         self.act  = get_activation(act_type)
 
     def forward(self, x):
-        if not self.depthwise:
-            return self.act(self.norm(self.conv(x)))
-        else:
-            # Depthwise conv
-            x = self.norm1(self.conv1(x))
-            # Pointwise conv
-            x = self.act(self.norm2(self.conv2(x)))
-            return x
+        return self.act(self.norm(self.conv(x)))
 
 class RepVGGBlock(nn.Module):
     def __init__(self,

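With the depthwise and norm options stripped out, ConvModule is now a fixed conv + BatchNorm + activation stack. A hedged usage sketch; the import path is assumed from the repo layout:

import torch
from yolo.models.yolov6.modules import ConvModule  # assumed import path

conv = ConvModule(in_dim=64, out_dim=128, kernel_size=3, padding=1, stride=2, act_type='silu')
x = torch.randn(1, 64, 40, 40)
y = conv(x)
print(y.shape)  # torch.Size([1, 128, 20, 20])
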
+ 10 - 6
yolo/models/yolov6/yolov6.py

@@ -41,17 +41,18 @@ class Yolov6(nn.Module):
         ## Head
         self.head     = Yolov6DetHead(cfg, self.fpn.out_dims)
         ## Pred
-        self.pred     = Yolov6DetPredLayer(cfg, self.fpn.out_dims)
+        self.pred     = Yolov6DetPredLayer(cfg,)
 
     def switch_deploy(self,):
         for m in self.modules():
             if hasattr(m, "switch_to_deploy"):
                 m.switch_to_deploy()
 
-    def post_process(self, cls_preds, box_preds):
+    def post_process(self, obj_preds, cls_preds, box_preds):
         """
         We process predictions at each scale hierarchically
         Input:
+            obj_preds: List[torch.Tensor] -> [[B, M, 1], ...], B=1
             cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
             box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
         Output:
@@ -63,12 +64,14 @@ class Yolov6(nn.Module):
         all_labels = []
         all_bboxes = []
         
-        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
+            obj_pred_i = obj_pred_i[0]
             cls_pred_i = cls_pred_i[0]
             box_pred_i = box_pred_i[0]
             if self.no_multi_labels:
                 # [M,]
-                scores, labels = torch.max(cls_pred_i.sigmoid(), dim=1)
+                scores, labels = torch.max(
+                    torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
 
                 # Keep top k top scoring indices only.
                 num_topk = min(self.topk_candidates, box_pred_i.size(0))
@@ -87,7 +90,7 @@ class Yolov6(nn.Module):
                 bboxes = box_pred_i[topk_idxs]
             else:
                 # [M, C] -> [MC,]
-                scores_i = cls_pred_i.sigmoid().flatten()
+                scores_i = torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()).flatten()
 
                 # Keep top k top scoring indices only.
                 num_topk = min(self.topk_candidates, box_pred_i.size(0))
@@ -143,11 +146,12 @@ class Yolov6(nn.Module):
         outputs['image_size'] = [x.shape[2], x.shape[3]]
 
         if not self.training:
+            all_obj_preds = outputs['pred_obj']
             all_cls_preds = outputs['pred_cls']
             all_box_preds = outputs['pred_box']
 
             # post process
-            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+            bboxes, scores, labels = self.post_process(all_obj_preds, all_cls_preds, all_box_preds)
             outputs = {
                 "scores": scores,
                 "labels": labels,

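With the objectness branch restored, the detection score becomes the geometric mean of the objectness and class probabilities, mirroring YOLOX. A minimal sketch of the scoring step in post_process, with illustrative shapes:

import torch

obj_pred = torch.randn(100, 1)   # [M, 1] objectness logits
cls_pred = torch.randn(100, 20)  # [M, C] class logits

scores = torch.sqrt(obj_pred.sigmoid() * cls_pred.sigmoid())  # [M, C], obj broadcast over C
top_scores, labels = torch.max(scores, dim=1)                 # best class per anchor
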
+ 31 - 63
yolo/models/yolov6/yolov6_head.py

@@ -15,17 +15,12 @@ class DetHead(nn.Module):
                  reg_head_dim :int  = 256,
                  num_cls_head :int  = 2,
                  num_reg_head :int  = 2,
-                 act_type     :str  = "silu",
-                 norm_type    :str  = "BN",
-                 depthwise    :bool = False):
+                 ):
         super().__init__()
         # --------- Basic Parameters ----------
         self.in_dim = in_dim
         self.num_cls_head = num_cls_head
         self.num_reg_head = num_reg_head
-        self.act_type = act_type
-        self.norm_type = norm_type
-        self.depthwise = depthwise
         
         # --------- Network Parameters ----------
         ## cls head
@@ -33,54 +28,20 @@ class DetHead(nn.Module):
         self.cls_head_dim = cls_head_dim
         for i in range(num_cls_head):
             if i == 0:
-                cls_feats.append(
-                    ConvModule(in_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
+                cls_feats.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
             else:
-                cls_feats.append(
-                    ConvModule(self.cls_head_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
+                cls_feats.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
         ## reg head
         reg_feats = []
         self.reg_head_dim = reg_head_dim
         for i in range(num_reg_head):
             if i == 0:
-                reg_feats.append(
-                    ConvModule(in_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
+                reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
             else:
-                reg_feats.append(
-                    ConvModule(self.reg_head_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=act_type,
-                              norm_type=norm_type,
-                              depthwise=depthwise)
-                              )
+                reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
         self.cls_feats = nn.Sequential(*cls_feats)
         self.reg_feats = nn.Sequential(*reg_feats)
 
-        self.init_weights()
-        
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
     def forward(self, x):
         """
             in_feats: (Tensor) [B, C, H, W]
@@ -97,17 +58,26 @@ class Yolov6DetHead(nn.Module):
         ## ----------- Network Parameters -----------
         self.multi_level_heads = nn.ModuleList(
             [DetHead(in_dim       = in_dims[level],
-                     cls_head_dim = in_dims[level],
-                     reg_head_dim = in_dims[level],
+                     cls_head_dim = round(cfg.head_dim * cfg.width),
+                     reg_head_dim = round(cfg.head_dim * cfg.width),
                      num_cls_head = cfg.num_cls_head,
                      num_reg_head = cfg.num_reg_head,
-                     act_type     = cfg.head_act,
-                     norm_type    = cfg.head_norm,
-                     depthwise    = cfg.head_depthwise)
-                     for level in range(cfg.num_levels)
-                     ])
+                     ) for level in range(cfg.num_levels)])
         # --------- Basic Parameters ----------
         self.in_dims = in_dims
+        self.cls_head_dim = cfg.head_dim
+        self.reg_head_dim = cfg.head_dim
+
+        # Initialize all layers
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
 
     def forward(self, feats):
         """
@@ -130,22 +100,21 @@ if __name__=='__main__':
     from thop import profile
     # Model config
     
-    # YOLOv3-Base config
-    class Yolov6BaseConfig(object):
+    # YOLOx-Base config
+    class YoloxBaseConfig(object):
         def __init__(self) -> None:
             # ---------------- Model config ----------------
-            self.out_stride = 32
+            self.width    = 0.50
+            self.depth    = 0.34
+            self.out_stride = [8, 16, 32]
             self.max_stride = 32
             self.num_levels = 3
             ## Head
-            self.head_act  = 'lrelu'
-            self.head_norm = 'BN'
-            self.head_depthwise = False
             self.head_dim  = 256
             self.num_cls_head   = 2
             self.num_reg_head   = 2
 
-    cfg = Yolov6BaseConfig()
+    cfg = YoloxBaseConfig()
     # Build a head
     pyramid_feats = [torch.randn(1, cfg.head_dim, 80, 80),
                      torch.randn(1, cfg.head_dim, 40, 40),
@@ -158,12 +127,11 @@ if __name__=='__main__':
     cls_feats, reg_feats = head(pyramid_feats)
     t1 = time.time()
     print('Time: ', t1 - t0)
-    for cls_f, reg_f in zip(cls_feats, reg_feats):
-        print(cls_f.shape, reg_f.shape)
+    print("====== Yolox Head output ======")
+    for level, (cls_f, reg_f) in enumerate(zip(cls_feats, reg_feats)):
+        print("- Level-{} : ".format(level), cls_f.shape, reg_f.shape)
 
-    print('==============================')
     flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
     print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))  
-    
+    print('Params : {:.2f} M'.format(params / 1e6))

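The head channels are no longer tied to the FPN output dims; they now scale as round(head_dim * width). A quick arithmetic check under assumed width factors:

# Hedged check: per-level head channels for head_dim = 256.
for width in [0.25, 0.50, 0.75, 1.0]:
    print(width, round(256 * width))  # 64, 128, 192, 256
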
+ 2 - 6
yolo/models/yolov6/yolov6_neck.py

@@ -18,12 +18,8 @@ class SPPF(nn.Module):
         inter_dim = in_dim // 2
         self.out_dim = out_dim
         ## ----------- Network Parameters -----------
-        self.cv1 = ConvModule(in_dim, inter_dim,
-                             kernel_size=1, padding=0, stride=1,
-                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
-        self.cv2 = ConvModule(inter_dim * 4, out_dim,
-                             kernel_size=1, padding=0, stride=1,
-                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1, padding=0, stride=1, act_type="silu")
+        self.cv2 = ConvModule(inter_dim * 4, out_dim, kernel_size=1, padding=0, stride=1, act_type="silu")
         self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
 
     def forward(self, x):

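The SPPF keeps a single 5x5 max-pool and applies it three times in sequence; cascaded 5x5 pools cover effective receptive fields of roughly 5, 9, and 13, matching the parallel pooling sizes of the original SPP at lower cost. A minimal sketch of the pooling cascade, with illustrative shapes:

import torch
import torch.nn as nn

m  = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
x  = torch.randn(1, 256, 20, 20)  # stand-in for cv1's output (inter_dim channels)
y1 = m(x)        # ~5x5 receptive field
y2 = m(y1)       # ~9x9
y3 = m(y2)       # ~13x13
out = torch.cat([x, y1, y2, y3], dim=1)  # [1, 1024, 20, 20], the inter_dim * 4 input to cv2
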
+ 18 - 18
yolo/models/yolov6/yolov6_pafpn.py

@@ -14,49 +14,48 @@ class Yolov6PaFPN(nn.Module):
     def __init__(self, cfg, in_dims: List = [256, 512, 1024]):
         super(Yolov6PaFPN, self).__init__()
         self.in_dims = in_dims
-        self.model_scale = cfg.scale
+        self.model_scale = cfg.model_scale
         c3, c4, c5 = in_dims
 
         # ---------------------- Yolov6's Top down FPN ----------------------
         ## P5 -> P4
         self.reduce_layer_1   = ConvModule(c5, round(256*cfg.width),
-                                          kernel_size=1, padding=0, stride=1,
-                                          act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
+                                           kernel_size=1, padding=0, stride=1, act_type="silu",)
         self.top_down_layer_1 = self.make_block(in_dim     = c4 + round(256*cfg.width),
                                                 out_dim    = round(256*cfg.width),
-                                                num_blocks = round(12*cfg.depth))
+                                                num_blocks = round(12*cfg.depth),
+                                                )
 
         ## P4 -> P3
         self.reduce_layer_2   = ConvModule(round(256*cfg.width), round(128*cfg.width),
-                                          kernel_size=1, padding=0, stride=1,
-                                          act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
+                                           kernel_size=1, padding=0, stride=1, act_type="silu",)
         self.top_down_layer_2 = self.make_block(in_dim     = c3 + round(128*cfg.width),
                                                 out_dim    = round(128*cfg.width),
-                                                num_blocks = round(12*cfg.depth))
+                                                num_blocks = round(12*cfg.depth),
+                                                )
         
         # ---------------------- Yolov6's Bottom up PAN ----------------------
         ## P3 -> P4
         self.downsample_layer_1 = ConvModule(round(128*cfg.width), round(128*cfg.width),
-                                            kernel_size=3, padding=1, stride=2,
-                                            act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
+                                             kernel_size=3, padding=1, stride=2, act_type="silu",)
         self.bottom_up_layer_1  = self.make_block(in_dim     = round(128*cfg.width) + round(128*cfg.width),
                                                   out_dim    = round(256*cfg.width),
-                                                  num_blocks = round(12*cfg.depth))
+                                                  num_blocks = round(12*cfg.depth),
+                                                  )
 
         ## P4 -> P5
         self.downsample_layer_2 = ConvModule(round(256*cfg.width), round(256*cfg.width),
-                                            kernel_size=3, padding=1, stride=2,
-                                            act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
+                                             kernel_size=3, padding=1, stride=2, act_type="silu",)
         self.bottom_up_layer_2  = self.make_block(in_dim     = round(256*cfg.width) + round(256*cfg.width),
                                                   out_dim    = round(512*cfg.width),
-                                                  num_blocks = round(12*cfg.depth))
+                                                  num_blocks = round(12*cfg.depth),
+                                                  )
 
         # ---------------------- Yolov6's output projection ----------------------
         self.out_layers = nn.ModuleList([
-            ConvModule(in_dim, in_dim, kernel_size=1,
-                      act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
-                      for in_dim in [round(128*cfg.width), round(256*cfg.width), round(512*cfg.width)]
-                      ])
+            ConvModule(in_dim, in_dim, kernel_size=1, act_type="silu",)
+                       for in_dim in [round(128*cfg.width), round(256*cfg.width), round(512*cfg.width)]
+                       ])
         self.out_dims = [round(128*cfg.width), round(256*cfg.width), round(512*cfg.width)]
 
     def make_block(self, in_dim, out_dim, num_blocks=1):
@@ -72,7 +71,8 @@ class Yolov6PaFPN(nn.Module):
         else:
             raise NotImplementedError("Unknown model scale: {}".format(self.model_scale))
             
-        return block        
+        return block      
+      
     def forward(self, features):
         c3, c4, c5 = features
         

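The PaFPN stages scale with the global multipliers: channel counts as round(C * width) and block counts as round(12 * depth). A quick check under assumed small-model factors:

# Hedged check, assuming width = 0.50 and depth = 0.34.
width, depth = 0.50, 0.34
out_dims   = [round(128 * width), round(256 * width), round(512 * width)]  # [64, 128, 256]
num_blocks = round(12 * depth)                                             # 4
print(out_dims, num_blocks)
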
+ 71 - 8
yolo/models/yolov6/yolov6_pred.py

@@ -19,6 +19,7 @@ class DetPredLayer(nn.Module):
         self.num_classes = num_classes
 
         # --------- Network Parameters ----------
+        self.obj_pred = nn.Conv2d(self.cls_dim, 1, kernel_size=1)
         self.cls_pred = nn.Conv2d(self.cls_dim, num_classes, kernel_size=1)
         self.reg_pred = nn.Conv2d(self.reg_dim, 4, kernel_size=1)                
 
@@ -28,6 +29,10 @@ class DetPredLayer(nn.Module):
         # Init bias
         init_prob = 0.01
         bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # obj pred
+        b = self.obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        self.obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
         # cls pred
         b = self.cls_pred.bias.view(1, -1)
         b.data.fill_(bias_value.item())
@@ -56,6 +61,7 @@ class DetPredLayer(nn.Module):
         
     def forward(self, cls_feat, reg_feat):
         # Prediction layers
+        obj_pred = self.obj_pred(reg_feat)
         cls_pred = self.cls_pred(cls_feat)
         reg_pred = self.reg_pred(reg_feat)
 
@@ -67,6 +73,7 @@ class DetPredLayer(nn.Module):
 
         # Reshape the predictions to make downstream processing easier
         # [B, C, H, W] -> [B, H, W, C] -> [B, H*W, C]
+        obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
         cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
         reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
         
@@ -78,7 +85,8 @@ class DetPredLayer(nn.Module):
         box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
 
         # output dict
-        outputs = {"pred_cls": cls_pred,       # (torch.Tensor) [B, M, C]
+        outputs = {"pred_obj": obj_pred,       # (torch.Tensor) [B, M, 1]
+                   "pred_cls": cls_pred,       # (torch.Tensor) [B, M, C]
                    "pred_reg": reg_pred,       # (torch.Tensor) [B, M, 4]
                    "pred_box": box_pred,       # (torch.Tensor) [B, M, 4]
                    "anchors" : anchors,        # (torch.Tensor) [M, 2]
@@ -90,32 +98,35 @@ class DetPredLayer(nn.Module):
 
 ## Multi-level pred layer
 class Yolov6DetPredLayer(nn.Module):
-    def __init__(self, cfg, in_dims):
+    def __init__(self, cfg):
         super().__init__()
         # --------- Basic Parameters ----------
         self.cfg = cfg
+        self.num_levels = len(cfg.out_stride)
 
         # ----------- Network Parameters -----------
         ## pred layers
         self.multi_level_preds = nn.ModuleList(
-            [DetPredLayer(cls_dim      = in_dims[level],
-                          reg_dim      = in_dims[level],
+            [DetPredLayer(cls_dim      = round(cfg.head_dim * cfg.width),
+                          reg_dim      = round(cfg.head_dim * cfg.width),
                           stride       = cfg.out_stride[level],
                           num_classes  = cfg.num_classes,)
-                          for level in range(cfg.num_levels)
+                          for level in range(self.num_levels)
                           ])
 
     def forward(self, cls_feats, reg_feats):
         all_anchors = []
         all_fmp_sizes = []
+        all_obj_preds = []
         all_cls_preds = []
         all_reg_preds = []
         all_box_preds = []
-        for level in range(self.cfg.num_levels):
+        for level in range(self.num_levels):
             # -------------- Single-level prediction --------------
             outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
 
             # collect results
+            all_obj_preds.append(outputs["pred_obj"])
             all_cls_preds.append(outputs["pred_cls"])
             all_reg_preds.append(outputs["pred_reg"])
             all_box_preds.append(outputs["pred_box"])
@@ -123,7 +134,8 @@ class Yolov6DetPredLayer(nn.Module):
             all_anchors.append(outputs["anchors"])
         
         # output dict
-        outputs = {"pred_cls":  all_cls_preds,         # List(Tensor) [B, M, C]
+        outputs = {"pred_obj":  all_obj_preds,         # List(Tensor) [B, M, 1]
+                   "pred_cls":  all_cls_preds,         # List(Tensor) [B, M, C]
                    "pred_reg":  all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
                    "pred_box":  all_box_preds,         # List(Tensor) [B, M, 4]
                    "fmp_sizes": all_fmp_sizes,         # List(Tensor) [M, 1]
@@ -132,4 +144,55 @@ class Yolov6DetPredLayer(nn.Module):
                    }
 
         return outputs
-    
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv6-Base config
+    class Yolov6BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 1.0
+            self.depth    = 1.0
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            ## Head
+            self.head_dim  = 256
+
+    cfg = Yolov6BaseConfig()
+    cfg.num_classes = 20
+    # Build a pred layer
+    pred = Yolov6DetPredLayer(cfg)
+
+    # Inference
+    cls_feats = [torch.randn(1, cfg.head_dim, 80, 80),
+                 torch.randn(1, cfg.head_dim, 40, 40),
+                 torch.randn(1, cfg.head_dim, 20, 20),]
+    reg_feats = [torch.randn(1, cfg.head_dim, 80, 80),
+                 torch.randn(1, cfg.head_dim, 40, 40),
+                 torch.randn(1, cfg.head_dim, 20, 20),]
+    t0 = time.time()
+    output = pred(cls_feats, reg_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('====== Pred output ======= ')
+    pred_obj = output["pred_obj"]
+    pred_cls = output["pred_cls"]
+    pred_reg = output["pred_reg"]
+    pred_box = output["pred_box"]
+    anchors  = output["anchors"]
+    
+    for level in range(len(cfg.out_stride)):
+        print("- Level-{} : objectness       -> {}".format(level, pred_obj[level].shape))
+        print("- Level-{} : classification   -> {}".format(level, pred_cls[level].shape))
+        print("- Level-{} : delta regression -> {}".format(level, pred_reg[level].shape))
+        print("- Level-{} : bbox regression  -> {}".format(level, pred_box[level].shape))
+        print("- Level-{} : anchor boxes     -> {}".format(level, anchors[level].shape))
+
+    flops, params = profile(pred, inputs=(cls_feats, reg_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
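
The new obj_pred branch reuses the same prior-probability bias initialization as the class branch. A quick check that the chosen bias makes early sigmoid scores start near the prior:

import torch

init_prob  = 0.01
bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
print(bias_value.item())                 # ~ -4.595
print(torch.sigmoid(bias_value).item())  # ~ 0.0100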