yjh0410 1 year ago
parent
commit
6eaf398988

+ 27 - 13
odlab/config/fcos_config.py

@@ -5,10 +5,10 @@ def build_fcos_config(args):
         return Fcos_R18_1x_Config()
     elif args.model == 'fcos_r50_1x':
         return Fcos_R50_1x_Config()
-    elif args.model == 'fcos_rt_r18_1x':
-        return FcosRT_R18_1x_Config()
-    elif args.model == 'fcos_rt_r50_1x':
-        return FcosRT_R50_1x_Config()
+    elif args.model == 'fcos_rt_r18_3x':
+        return FcosRT_R18_3x_Config()
+    elif args.model == 'fcos_rt_r50_3x':
+        return FcosRT_R50_3x_Config()
     else:
         raise NotImplementedError("No config for model: {}".format(args.model))
 
@@ -118,7 +118,7 @@ class Fcos_R50_1x_Config(FcosBaseConfig):
         ## Backbone
         self.backbone = "resnet50"
 
-class FcosRT_R18_1x_Config(FcosBaseConfig):
+class FcosRT_R18_3x_Config(FcosBaseConfig):
     def __init__(self) -> None:
         super().__init__()
         ## Backbone
@@ -132,20 +132,27 @@ class FcosRT_R18_1x_Config(FcosBaseConfig):
         self.fpn_p7_feat = False
         self.fpn_p6_from_c5  = False
 
+        # --------- Head ---------
+        self.head = 'fcos_rt_head'
+        self.head_dim = 256
+        self.num_cls_head = 4
+        self.num_reg_head = 4
+        self.head_act     = 'relu'
+        self.head_norm    = 'GN'
+
         # --------- Label Assignment ---------
         self.matcher = 'simota'
-        self.matcher_hpy = {'soft_center_radius': 2.5,
-                            'topk_candidates': 13},
+        self.matcher_hpy = {'soft_center_radius': 3.0,
+                            'topk_candidates': 13}
 
         # --------- Loss weight ---------
         self.focal_loss_alpha = 0.25
         self.focal_loss_gamma = 2.0
         self.loss_cls_weight  = 1.0
         self.loss_reg_weight  = 2.0
-        self.loss_ctn_weight  = 0.5
 
         # --------- Train epoch ---------
-        self.max_epoch = 36,        # 3x
+        self.max_epoch = 36         # 3x
         self.lr_epoch  = [24, 33]   # 3x
 
         # --------- Data process ---------
@@ -166,7 +173,7 @@ class FcosRT_R18_1x_Config(FcosBaseConfig):
             {'name': 'RandomResize'},
         ]
 
-class FcosRT_R50_1x_Config(FcosBaseConfig):
+class FcosRT_R50_3x_Config(FcosBaseConfig):
     def __init__(self) -> None:
         super().__init__()
         ## Backbone
@@ -180,20 +187,27 @@ class FcosRT_R50_1x_Config(FcosBaseConfig):
         self.fpn_p7_feat = False
         self.fpn_p6_from_c5  = False
 
+        # --------- Head ---------
+        self.head = 'fcos_rt_head'
+        self.head_dim = 256
+        self.num_cls_head = 4
+        self.num_reg_head = 4
+        self.head_act     = 'relu'
+        self.head_norm    = 'GN'
+
         # --------- Label Assignment ---------
         self.matcher = 'simota'
         self.matcher_hpy = {'soft_center_radius': 2.5,
-                            'topk_candidates': 13},
+                            'topk_candidates': 13}
 
         # --------- Loss weight ---------
         self.focal_loss_alpha = 0.25
         self.focal_loss_gamma = 2.0
         self.loss_cls_weight  = 1.0
         self.loss_reg_weight  = 2.0
-        self.loss_ctn_weight  = 0.5
 
         # --------- Train epoch ---------
-        self.max_epoch = 36,        # 3x
+        self.max_epoch = 36         # 3x
         self.lr_epoch  = [24, 33]   # 3x
 
         # --------- Data process ---------
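
Side note on the trailing commas removed in this hunk: in Python they silently wrap the assigned value in a one-element tuple, which is why matcher_hpy and max_epoch were broken before this change. A minimal standalone illustration (plain Python, not repo code):

# Minimal illustration (not repo code): a trailing comma wraps the value in a tuple.
matcher_hpy = {'soft_center_radius': 3.0, 'topk_candidates': 13},   # -> ({...},)
max_epoch   = 36,                                                    # -> (36,)
print(type(matcher_hpy), type(max_epoch))   # <class 'tuple'> <class 'tuple'>

matcher_hpy = {'soft_center_radius': 3.0, 'topk_candidates': 13}    # dict, as intended
max_epoch   = 36                                                     # int, as intended
print(type(matcher_hpy), type(max_epoch))   # <class 'dict'> <class 'int'>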

+ 5 - 2
odlab/models/detectors/__init__.py

@@ -1,14 +1,17 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import torch
 
-from .fcos.build  import build_fcos
+from .fcos.build  import build_fcos, build_fcos_rt
 from .yolof.build import build_yolof
 
 
 def build_model(args, cfg, is_val=False):
     # ------------ build object detector ------------
+    ## RT-FCOS    
+    if   'fcos_rt' in args.model:
+        model, criterion = build_fcos_rt(cfg, is_val)
     ## FCOS    
-    if 'fcos' in args.model:
+    elif 'fcos' in args.model:
         model, criterion = build_fcos(cfg, is_val)
     ## YOLOF    
     elif 'yolof' in args.model:
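
The branch order matters here: 'fcos' is a substring of every 'fcos_rt_*' model name, so the RT check has to run first or the plain-FCOS branch would shadow it. A tiny illustrative check (not repo code):

# Both tests match an RT model name, so 'fcos_rt' must be checked before 'fcos'.
print('fcos' in 'fcos_rt_r18_3x')      # True  -> would wrongly select build_fcos
print('fcos_rt' in 'fcos_rt_r18_3x')   # True  -> correct branch, checked first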

+ 19 - 1
odlab/models/detectors/fcos/build.py

@@ -2,7 +2,7 @@
 # -*- coding:utf-8 -*-
 
 from .criterion import SetCriterion
-from .fcos import FCOS
+from .fcos import FCOS, FcosRT
 
 
 # build FCOS
@@ -21,4 +21,22 @@ def build_fcos(cfg, is_val=False):
         # build criterion for training
         criterion = SetCriterion(cfg)
 
+    return model, criterion
+
+# build real-time FCOS
+def build_fcos_rt(cfg, is_val=False):
+    # -------------- Build FCOS --------------
+    model = FcosRT(cfg         = cfg,
+                   num_classes = cfg.num_classes,
+                   conf_thresh = cfg.train_conf_thresh if is_val else cfg.test_conf_thresh,
+                   nms_thresh  = cfg.train_nms_thresh  if is_val else cfg.test_nms_thresh,
+                   topk        = cfg.train_topk        if is_val else cfg.test_topk,
+                   )
+            
+    # -------------- Build Criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+
     return model, criterion
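
A hypothetical usage sketch of the new builder, assuming odlab is importable as a package and using the config class renamed above; the actual call site in the repo may differ:

# Hypothetical usage sketch (call site is an assumption, not repo code).
from odlab.config.fcos_config import FcosRT_R18_3x_Config
from odlab.models.detectors.fcos.build import build_fcos_rt

cfg = FcosRT_R18_3x_Config()
model, criterion = build_fcos_rt(cfg, is_val=True)   # criterion is only built when is_val
model.eval()                                          # eval mode triggers post_process + NMS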

+ 63 - 45
odlab/models/detectors/fcos/criterion.py

@@ -19,18 +19,20 @@ class SetCriterion(nn.Module):
         self.alpha = cfg.focal_loss_alpha
         self.gamma = cfg.focal_loss_gamma
         # ------------- Loss weight -------------
-        self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
-                            'loss_reg': cfg.loss_reg_weight,
-                            'loss_ctn': cfg.loss_ctn_weight}
-        # ------------- Matcher -------------
+        # ------------- Matcher & Loss weight -------------
         self.matcher_cfg = cfg.matcher_hpy
         if cfg.matcher == 'fcos_matcher':
+            self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
+                                'loss_reg': cfg.loss_reg_weight,
+                                'loss_ctn': cfg.loss_ctn_weight}
             self.matcher = FcosMatcher(cfg.num_classes,
                                        self.matcher_cfg['center_sampling_radius'],
                                        self.matcher_cfg['object_sizes_of_interest'],
                                        [1., 1., 1., 1.]
                                        )
         elif cfg.matcher == 'simota':
+            self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
+                                'loss_reg': cfg.loss_reg_weight}
             self.matcher = SimOtaMatcher(cfg.num_classes,
                                          self.matcher_cfg['soft_center_radius'],
                                          self.matcher_cfg['topk_candidates'])
@@ -47,6 +49,33 @@ class SetCriterion(nn.Module):
 
         return loss_cls.sum() / num_boxes
 
+    def loss_labels_qfl(self, pred_cls, target, beta=2.0, num_boxes=1.0):
+        # Quality FocalLoss
+        """
+            pred_cls: (torch.Tensor): [N, C]。
+            target:   (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N)
+        """
+        label, score = target
+        pred_sigmoid = pred_cls.sigmoid()
+        scale_factor = pred_sigmoid
+        zerolabel = scale_factor.new_zeros(pred_cls.shape)
+
+        ce_loss = F.binary_cross_entropy_with_logits(
+            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
+        
+        bg_class_ind = pred_cls.shape[-1]
+        pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
+        if pos.shape[0] > 0:
+            pos_label = label[pos].long()
+
+            scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
+
+            ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
+                pred_cls[pos, pos_label], score[pos],
+                reduction='none') * scale_factor.abs().pow(beta)
+
+        return ce_loss.sum() / num_boxes
+    
     def loss_bboxes_ltrb(self, pred_delta, tgt_delta, bbox_quality=None, num_boxes=1.0):
         """
             pred_box: (Tensor) [N, 4]
@@ -157,27 +186,27 @@ class SetCriterion(nn.Module):
             outputs['pred_cls']: (Tensor) [B, M, C]
             outputs['pred_reg']: (Tensor) [B, M, 4]
             outputs['pred_box']: (Tensor) [B, M, 4]
-            outputs['pred_ctn']: (Tensor) [B, M, 1]
             outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
             targets: (List) [dict{'boxes': [...], 
                                  'labels': [...], 
                                  'orig_size': ...}, ...]
         """
         # -------------------- Pre-process --------------------
-        device = outputs['pred_cls'][0].device
-        batch_size =  outputs['pred_cls'][0].shape[0]
+        bs          = outputs['pred_cls'][0].shape[0]
+        device      = outputs['pred_cls'][0].device
         fpn_strides = outputs['strides']
-        anchors = outputs['anchors']
-        pred_cls = torch.cat(outputs['pred_cls'], dim=1)   # [B, M, C]
-        pred_box = torch.cat(outputs['pred_box'], dim=1)   # [B, M, 4]
-        pred_ctn = torch.cat(outputs['pred_ctn'], dim=1)   # [B, M, 1]
+        anchors     = outputs['anchors']
+        # cls preds: [B, M, C]
+        # box preds: [B, M, 4]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
         masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
 
         # -------------------- Label Assignment --------------------
-        gt_classes = []
-        gt_bboxes = []
-        gt_centerness = []
-        for batch_idx in range(batch_size):
+        cls_targets = []
+        box_targets = []
+        assign_metrics = []
+        for batch_idx in range(bs):
             tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
             tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
             # refine target
@@ -189,52 +218,41 @@ class SetCriterion(nn.Module):
             # label assignment
             assigned_result = self.matcher(fpn_strides=fpn_strides,
                                            anchors=anchors,
-                                           pred_cls=pred_cls[batch_idx].detach(),
-                                           pred_box=pred_box[batch_idx].detach(),
-                                           pred_iou=pred_ctn[batch_idx].detach(),
+                                           pred_cls=cls_preds[batch_idx].detach(),
+                                           pred_box=box_preds[batch_idx].detach(),
                                            gt_labels=tgt_labels,
                                            gt_bboxes=tgt_bboxes
                                            )
-            gt_classes.append(assigned_result['assigned_labels'])
-            gt_bboxes.append(assigned_result['assigned_bboxes'])
-            gt_centerness.append(assigned_result['assign_metrics'])
+            cls_targets.append(assigned_result['assigned_labels'])
+            box_targets.append(assigned_result['assigned_bboxes'])
+            assign_metrics.append(assigned_result['assign_metrics'])
 
         # List[B, M, C] -> Tensor[BM, C]
-        gt_classes = torch.cat(gt_classes, dim=0)         # [BM,]
-        gt_bboxes = torch.cat(gt_bboxes, dim=0)           # [BM, 4]
-        gt_centerness = torch.cat(gt_centerness, dim=0)   # [BM,]
+        cls_targets = torch.cat(cls_targets, dim=0)
+        box_targets = torch.cat(box_targets, dim=0)
+        assign_metrics = torch.cat(assign_metrics, dim=0)
 
-        valid_idxs = (gt_classes >= 0) & masks
-        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
-        num_foreground = foreground_idxs.sum()
+        valid_idxs = (cls_targets >= 0) & masks
+        foreground_idxs = (cls_targets >= 0) & (cls_targets != self.num_classes)
+        num_fgs = assign_metrics.sum()
 
         if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_foreground)
-        num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
 
         # -------------------- classification loss --------------------
-        pred_cls = pred_cls.view(-1, self.num_classes)
-        gt_classes_target = torch.zeros_like(pred_cls)
-        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1
-        loss_labels = self.loss_labels(pred_cls[valid_idxs], gt_classes_target[valid_idxs], num_foreground)
+        cls_preds = cls_preds.view(-1, self.num_classes)[valid_idxs]
+        qfl_targets = (cls_targets[valid_idxs], assign_metrics[valid_idxs])
+        loss_labels = self.loss_labels_qfl(cls_preds, qfl_targets, 2.0, num_fgs)
 
         # -------------------- regression loss --------------------
-        pred_box = pred_box.view(-1, 4)
-        pred_box_pos = pred_box[foreground_idxs]
-        gt_box_pos = gt_bboxes[foreground_idxs]
-        loss_bboxes = self.loss_bboxes_xyxy(pred_box_pos, gt_box_pos, num_foreground)
-
-        # -------------------- centerness loss --------------------
-        pred_ctn = pred_ctn.view(-1)
-        pred_ctn_pos = pred_ctn[foreground_idxs]
-        gt_ctn_pos = gt_centerness[foreground_idxs]
-        loss_centerness = F.binary_cross_entropy_with_logits(pred_ctn_pos, gt_ctn_pos, reduction='none')
-        loss_centerness = loss_centerness.sum() / num_foreground
+        box_preds_pos = box_preds.view(-1, 4)[foreground_idxs]
+        box_targets_pos = box_targets[foreground_idxs]
+        loss_bboxes = self.loss_bboxes_xyxy(box_preds_pos, box_targets_pos, num_fgs)
 
         loss_dict = dict(
                 loss_cls = loss_labels,
                 loss_reg = loss_bboxes,
-                loss_ctn = loss_centerness,
         )
 
         return loss_dict
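
For reference, a self-contained sketch of the Quality Focal Loss introduced above; the shapes and the beta=2.0 default follow the diff, while the toy inputs below are illustrative only:

# Standalone sketch of the Quality Focal Loss used by SetCriterion above.
import torch
import torch.nn.functional as F

def quality_focal_loss(pred_cls, label, score, beta=2.0, num_boxes=1.0):
    # pred_cls: [N, C] raw logits; label: [N] class index (C marks background);
    # score:    [N] target quality (e.g. assign metric / IoU) for positives.
    pred_sigmoid = pred_cls.sigmoid()
    # every position starts as a negative, down-weighted by sigmoid(pred)^beta
    loss = F.binary_cross_entropy_with_logits(
        pred_cls, torch.zeros_like(pred_cls), reduction='none') * pred_sigmoid.pow(beta)
    # positives: pull the predicted score toward the assigned quality
    pos = ((label >= 0) & (label < pred_cls.shape[-1])).nonzero().squeeze(1)
    if pos.numel() > 0:
        pos_label = label[pos].long()
        scale = (score[pos] - pred_sigmoid[pos, pos_label]).abs().pow(beta)
        loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
            pred_cls[pos, pos_label], score[pos], reduction='none') * scale
    return loss.sum() / num_boxes

# toy check: 5 anchors, 3 classes, anchors 0 and 2 are positives
logits = torch.randn(5, 3)
label  = torch.tensor([1, 3, 0, 3, 3])             # 3 == num_classes marks background
score  = torch.tensor([0.8, 0.0, 0.6, 0.0, 0.0])
print(quality_focal_loss(logits, label, score, num_boxes=2.0))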

+ 112 - 0
odlab/models/detectors/fcos/fcos.py

@@ -124,3 +124,115 @@ class FCOS(nn.Module):
             }
 
         return outputs 
+
+# ------------------------ Real-time FCOS ------------------------
+class FcosRT(nn.Module):
+    def __init__(self, 
+                 cfg,
+                 num_classes :int   = 80, 
+                 conf_thresh :float = 0.05,
+                 nms_thresh  :float = 0.6,
+                 topk        :int   = 1000,
+                 ca_nms      :bool  = False):
+        super(FcosRT, self).__init__()
+        # ---------------------- Basic Parameters ----------------------
+        self.cfg = cfg
+        self.topk = topk
+        self.num_classes = num_classes
+        self.conf_thresh = conf_thresh
+        self.nms_thresh = nms_thresh
+        self.ca_nms = ca_nms
+
+        # ---------------------- Network Parameters ----------------------
+        ## Backbone
+        self.backbone, feat_dims = build_backbone(cfg)
+
+        ## Neck
+        self.fpn = build_neck(cfg, feat_dims, cfg.head_dim)
+        
+        ## Heads
+        self.head = build_head(cfg, cfg.head_dim, cfg.head_dim)
+
+    def post_process(self, cls_preds, box_preds):
+        """
+        Input:
+            cls_preds: List(Tensor) [[B, H x W, C], ...]
+            box_preds: List(Tensor) [[B, H x W, 4], ...]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            
+            # (H x W x C,)
+            scores_i = cls_pred_i.sigmoid().flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk, box_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            topk_idxs = topk_idxs[keep_idxs]
+
+            # final scores
+            scores = topk_scores[keep_idxs]
+            # final labels
+            labels = topk_idxs % self.num_classes
+            # final bboxes
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, self.ca_nms)
+
+        return bboxes, scores, labels
+
+    def forward(self, src, src_mask=None):
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(src)
+
+        # ---------------- Neck ----------------
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        outputs = self.head(pyramid_feats, src_mask)
+
+        if not self.training:
+            # ---------------- PostProcess ----------------
+            cls_pred = outputs["pred_cls"]
+            box_pred = outputs["pred_box"]
+            bboxes, scores, labels = self.post_process(cls_pred, box_pred)
+            # normalize bbox
+            bboxes[..., 0::2] /= src.shape[-1]
+            bboxes[..., 1::2] /= src.shape[-2]
+            bboxes = bboxes.clip(0., 1.)
+
+            outputs = {
+                'scores': scores,
+                'labels': labels,
+                'bboxes': bboxes
+            }
+
+        return outputs 
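
A sketch of the flattened top-k decoding that post_process above relies on: per-level class scores are flattened to length H*W*C, so label = idx % C and anchor = idx // C. The shapes and thresholds below are toy values:

# Standalone sketch of the flattened top-k decoding (toy shapes, not repo code).
import torch

num_classes = 80
cls_pred = torch.randn(100, num_classes)     # [HW, C] logits for one pyramid level
box_pred = torch.rand(100, 4) * 512          # [HW, 4] decoded boxes

scores_all = cls_pred.sigmoid().flatten()    # [HW * C]
prob, topk_idxs = scores_all.sort(descending=True)
topk_scores, topk_idxs = prob[:10], topk_idxs[:10]

keep       = topk_scores > 0.4               # confidence filter
scores     = topk_scores[keep]
labels     = topk_idxs[keep] % num_classes
anchor_idx = torch.div(topk_idxs[keep], num_classes, rounding_mode='floor')
bboxes     = box_pred[anchor_idx]
print(scores.shape, labels.shape, bboxes.shape)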

+ 4 - 6
odlab/models/detectors/fcos/matcher.py

@@ -238,7 +238,6 @@ class SimOtaMatcher(object):
                  anchors, 
                  pred_cls, 
                  pred_box,
-                 pred_iou,
                  gt_labels,
                  gt_bboxes):
         # [M,]
@@ -275,13 +274,12 @@ class SimOtaMatcher(object):
 
         # ----------------------------------- classification cost -----------------------------------
         ## select the predicted scores corresponded to the gt_labels
-        pred_scores = torch.sqrt(pred_cls.sigmoid() * pred_iou.sigmoid())
-        pred_scores = pred_scores.permute(1, 0)  # [M, C] -> [C, M]
-        pairwise_pred_scores = pred_scores[gt_labels.long(), :].float()   # [N, M]
+        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
+        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()   # [N, M]
         ## scale factor
-        scale_factor = (pair_wise_ious - pairwise_pred_scores).abs().pow(2.0)
+        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
         ## cls cost
-        pair_wise_cls_loss = F.binary_cross_entropy(
+        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
             pairwise_pred_scores, pair_wise_ious,
             reduction="none") * scale_factor # [N, M]
             

+ 4 - 2
odlab/models/head/__init__.py

@@ -1,5 +1,5 @@
 from .yolof_head     import YolofHead
-from .fcos_head      import FcosHead
+from .fcos_head      import FcosHead, FcosRTHead
 
 
 # build head
@@ -7,8 +7,10 @@ def build_head(cfg, in_dim, out_dim):
     print('==============================')
     print('Head: {}'.format(cfg.head))
     
-    if cfg.head == 'fcos_head':
+    if   cfg.head == 'fcos_head':
         model = FcosHead(cfg, in_dim, out_dim)
+    elif cfg.head == 'fcos_rt_head':
+        model = FcosRTHead(cfg, in_dim, out_dim)
     elif cfg.head == 'yolof_head':
         model = YolofHead(cfg, in_dim, out_dim)
 

+ 150 - 0
odlab/models/head/fcos_head.py

@@ -184,3 +184,153 @@ class FcosHead(nn.Module):
                    "mask": all_masks}          # List [B, M,]
 
         return outputs 
+
+class FcosRTHead(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim,):
+        super().__init__()
+        self.fmp_size = None
+        # ------------------ Basic parameters -------------------
+        self.cfg = cfg
+        self.in_dim = in_dim
+        self.stride       = cfg.out_stride
+        self.num_classes  = cfg.num_classes
+        self.num_cls_head = cfg.num_cls_head
+        self.num_reg_head = cfg.num_reg_head
+        self.act_type     = cfg.head_act
+        self.norm_type    = cfg.head_norm
+
+        # ------------------ Network parameters -------------------
+        ## cls head
+        cls_heads = []
+        self.cls_head_dim = out_dim
+        for i in range(self.num_cls_head):
+            if i == 0:
+                cls_heads.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+            else:
+                cls_heads.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+        
+        ## reg head
+        reg_heads = []
+        self.reg_head_dim = out_dim
+        for i in range(self.num_reg_head):
+            if i == 0:
+                reg_heads.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+            else:
+                reg_heads.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+        self.cls_heads = nn.Sequential(*cls_heads)
+        self.reg_heads = nn.Sequential(*reg_heads)
+
+        ## pred layers
+        self.cls_pred = nn.Conv2d(self.cls_head_dim, cfg.num_classes, kernel_size=3, padding=1)
+        self.reg_pred = nn.Conv2d(self.reg_head_dim, 4, kernel_size=3, padding=1)
+                
+        # init bias
+        self._init_layers()
+
+    def _init_layers(self):
+        for module in [self.cls_heads, self.reg_heads, self.cls_pred, self.reg_pred]:
+            for layer in module.modules():
+                if isinstance(layer, nn.Conv2d):
+                    torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
+                    if layer.bias is not None:
+                        torch.nn.init.constant_(layer.bias, 0)
+                if isinstance(layer, nn.GroupNorm):
+                    torch.nn.init.constant_(layer.weight, 1)
+                    if layer.bias is not None:
+                        torch.nn.init.constant_(layer.bias, 0)
+        # init the bias of cls pred
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        torch.nn.init.constant_(self.cls_pred.bias, bias_value)
+        
+    def get_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
+        anchors *= self.stride[level]
+
+        return anchors
+        
+    def decode_boxes(self, pred_deltas, anchors, stride):
+        """
+            pred_deltas: (List[Tensor]) [B, M, 4] or [M, 4] (dx, dy, dw, dh)
+            anchors:     (List[Tensor]) [1, M, 2] or [M, 2]
+        """
+        pred_cxcy = anchors + pred_deltas[..., :2] * stride
+        pred_bwbh = pred_deltas[..., 2:].exp() * stride
+
+        pred_x1y1 = pred_cxcy - 0.5 * pred_bwbh
+        pred_x2y2 = pred_cxcy + 0.5 * pred_bwbh
+
+        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return pred_box
+    
+    def forward(self, pyramid_feats, mask=None):
+        all_masks = []
+        all_anchors = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        for level, feat in enumerate(pyramid_feats):
+            # ------------------- Decoupled head -------------------
+            cls_feat = self.cls_heads(feat)
+            reg_feat = self.reg_heads(feat)
+
+            # ------------------- Generate anchor box -------------------
+            B, _, H, W = cls_feat.size()
+            fmp_size = [H, W]
+            anchors = self.get_anchors(level, fmp_size)   # [M, 2]
+            anchors = anchors.to(cls_feat.device)
+
+            # ------------------- Predict -------------------
+            cls_pred = self.cls_pred(cls_feat)
+            reg_pred = self.reg_pred(reg_feat)
+
+            # ------------------- Process preds -------------------
+            ## [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+            box_pred = self.decode_boxes(reg_pred, anchors, self.stride[level])
+            ## Adjust mask
+            if mask is not None:
+                # [B, H, W]
+                mask_i = torch.nn.functional.interpolate(mask[None].float(), size=[H, W]).bool()[0]
+                # [B, H, W] -> [B, M]
+                mask_i = mask_i.flatten(1)     
+                all_masks.append(mask_i)
+                
+            all_anchors.append(anchors)
+            all_cls_preds.append(cls_pred)
+            all_reg_preds.append(reg_pred)
+            all_box_preds.append(box_pred)
+
+        outputs = {"pred_cls": all_cls_preds,  # List [B, M, C]
+                   "pred_reg": all_reg_preds,  # List [B, M, 4]
+                   "pred_box": all_box_preds,  # List [B, M, 4]
+                   "anchors": all_anchors,     # List [B, M, 2]
+                   "strides": self.stride,
+                   "mask": all_masks}          # List [B, M,]
+
+        return outputs
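
A minimal sketch of the anchor grid and box decoding FcosRTHead uses above, with toy stride and feature-map values; reg_pred is treated as center offsets plus log-scale sizes:

# Minimal sketch of get_anchors / decode_boxes with toy values.
import torch

def get_anchors(fmp_size, stride):
    fmp_h, fmp_w = fmp_size
    ay, ax = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
    anchors = torch.stack([ax, ay], dim=-1).float().view(-1, 2) + 0.5   # [HW, 2] cell centers
    return anchors * stride

def decode_boxes(pred_deltas, anchors, stride):
    # (dx, dy) shift the anchor center, (dw, dh) are log-scale box sizes
    cxcy = anchors + pred_deltas[..., :2] * stride
    bwbh = pred_deltas[..., 2:].exp() * stride
    return torch.cat([cxcy - 0.5 * bwbh, cxcy + 0.5 * bwbh], dim=-1)    # [HW, 4] xyxy

anchors = get_anchors([4, 4], stride=8)       # 16 anchors on a 4x4 feature map
deltas  = torch.zeros(16, 4)                  # zero deltas -> 8x8 boxes at cell centers
print(decode_boxes(deltas, anchors, stride=8)[:2])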