yjh0410 1 rok pred
rodič
commit
fea233e774

+ 71 - 0
odlab/config/fcos_config.py

@@ -1,24 +1,32 @@
 # Fully Convolutional One-Stage object detector
 
 def build_fcos_config(args):  # Return the config object selected by args.model
+    # Standard FCOS, 1x schedule
     if   args.model == 'fcos_r18_1x':
         return Fcos_R18_1x_Config()
     elif args.model == 'fcos_r50_1x':
         return Fcos_R50_1x_Config()
     
+    # Standard FCOS, 3x schedule
     elif args.model == 'fcos_r18_3x':
         return Fcos_R18_3x_Config()
     elif args.model == 'fcos_r50_3x':
         return Fcos_R50_3x_Config()
     
+    # Real-time FCOS, 3x schedule
     elif args.model == 'fcos_rt_r18_3x':
         return FcosRT_R18_3x_Config()
     elif args.model == 'fcos_rt_r50_3x':
         return FcosRT_R50_3x_Config()
     
+    # E2E FCOS, 3x schedule
     elif args.model == 'fcos_e2e_r18_3x':
         return FcosE2E_R18_3x_Config()
 
+    # PSS FCOS, 3x schedule
+    elif args.model == 'fcos_pss_r18_3x':
+        return FcosPSS_R18_3x_Config()
+
     else:
         raise NotImplementedError("No config for model: {}".format(args.model))  # any other model name is rejected
 
@@ -285,3 +293,66 @@ class FcosE2E_R18_3x_Config(FcosBaseConfig):
             {'name': 'RandomResize'},
         ]
 
+# --------------- PSS-FCOS & 3x scheduler ---------------
+class FcosPSS_R18_3x_Config(FcosBaseConfig):  # PSS-FCOS settings: ResNet-18 backbone, 36-epoch (3x) schedule
+    def __init__(self) -> None:
+        super().__init__()
+        # --------- Backbone ---------
+        self.backbone = "resnet18"
+        self.max_stride = 32
+        self.out_stride = [8, 16, 32]  # one stride per head level; the head indexes this list by level
+
+        # --------- Neck ---------
+        self.neck = 'basic_fpn'
+        self.fpn_p6_feat = False
+        self.fpn_p7_feat = False
+        self.fpn_p6_from_c5  = False
+
+        # --------- Head ---------
+        self.head = 'fcos_pss_head'  # key that build_head maps to FcosPSSHead
+        self.head_dim = 256
+        self.num_cls_head = 4  # number of conv layers in the classification tower
+        self.num_reg_head = 4  # number of conv layers in the regression tower
+        self.head_act     = 'relu'
+        self.head_norm    = 'GN'
+
+        # --------- Post-process ---------
+        self.train_topk = 100
+        self.train_conf_thresh = 0.05
+        self.test_topk = 100
+        self.test_conf_thresh = 0.4
+
+        # --------- Label Assignment ---------
+        self.matcher = 'simota'
+        self.matcher_hpy = {'soft_center_radius': 3.0,
+                            'topk_candidates': 13}  # both keys are passed to AlignedOTAMatcher by the criterion
+
+        # --------- Loss weight ---------
+        self.focal_loss_alpha = 0.25
+        self.focal_loss_gamma = 2.0
+        self.loss_cls_weight  = 1.0
+        self.loss_reg_weight  = 2.0
+        self.loss_pss_weight  = 1.0
+
+        # --------- Train epoch ---------
+        self.max_epoch = 36         # 3x
+        self.lr_epoch  = [24, 33]   # 3x; presumably the epochs at which the LR is decayed - TODO confirm
+
+        # --------- Data process ---------
+        ## input size
+        self.train_min_size = [256, 288, 320, 352, 384, 416, 448, 480, 512, 544, 576, 608]   # short edge of image
+        self.train_max_size = 900
+        self.test_min_size  = [512]
+        self.test_max_size  = 736
+        ## Pixel mean & std
+        self.pixel_mean = [0.485, 0.456, 0.406]
+        self.pixel_std  = [0.229, 0.224, 0.225]
+        ## Transforms
+        self.box_format = 'xyxy'
+        self.normalize_coords = False
+        self.detr_style = False
+        self.trans_config = [
+            {'name': 'RandomHFlip'},
+            {'name': 'RandomResize'},
+        ]
+

+ 4 - 0
odlab/models/detectors/__init__.py

@@ -3,6 +3,7 @@ import torch
 
 from .fcos.build     import build_fcos, build_fcos_rt
 from .fcos_e2e.build import build_fcos_e2e
+from .fcos_pss.build import build_fcos_pss
 from .yolof.build    import build_yolof
 from .detr.build     import build_detr
 
@@ -15,6 +16,9 @@ def build_model(args, cfg, is_val=False):
     ## E2E-FCOS
     elif 'fcos_e2e' in args.model:
         model, criterion = build_fcos_e2e(cfg, is_val)
+    ## PSS-FCOS
+    elif 'fcos_pss' in args.model:
+        model, criterion = build_fcos_pss(cfg, is_val)
     ## FCOS    
     elif 'fcos' in args.model:
         model, criterion = build_fcos(cfg, is_val)

+ 20 - 0
odlab/models/detectors/fcos_pss/README.md

@@ -0,0 +1,20 @@
+# Empirical research on End-to-End FCOS
+Inspired by YOLOv10, I recently conducted an empirical study on FCOS to evaluate the **End-to-End detection** paradigm.
+
+## Experiments
+
+- COCO
+
+Incredibly, the FPS of the three FCOS are almost the same!
+
+| Model                | Scale      | FPS<sup>FP32<br>RTX 4060 | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | Weight | Logs |
+|----------------------|------------|--------------------------|------------------------|-------------------|--------|------|
+| FCOS_RT_R18_3x       |  512,736   |           56             |          35.8          |        53.3       | [ckpt](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/fcos_rt_r18_3x_coco.pth) | [log](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/FCOS-RT-R18-3x.txt) |
+| FCOS_RT_R18_3x (O2O) |  512,736   |           56             |          30.9          |        48.8       | [ckpt](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/fcos_rt_r18_3x_top1_coco.pth) | [log](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/FCOS-RT-R18-3x-COCO-top1.txt) |
+| FCOS_E2E_R18_3x      |  512,736   |           56             |          34.1          |        50.6       | [ckpt](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/fcos_e2e_r18_3x_coco.pth) | [log](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/FCOS-E2E-R18-3x-COCO.txt) |
+
+For **FCOS_RT_R18_3x**, we only use a one-to-many assigner to train `FCOS-RT-R18-3x` and evaluate it with NMS.
+
+For **FCOS_RT_R18_3x (O2O)**, we only use a one-to-one assigner to train `FCOS-RT-R18-3x` and evaluate it without NMS.
+
+For **FCOS_E2E_R18_3x**, we deploy two parallel detection heads, one using a one-to-many assigner (o2m head) and the other using a one-to-one assigner (o2o head). To avoid conflicts between the gradients returned by the o2o head and the o2m head, we truncate the gradients returned from the o2o head to the backbone and neck, and only allow the gradients returned from the o2m head to update the backbone and neck. This operation is consistent with the practice of YOLOv10. For evaluation, we remove the o2m head and only use the o2o head without NMS.

+ 18 - 0
odlab/models/detectors/fcos_pss/build.py

@@ -0,0 +1,18 @@
+from .fcos import FcosPSS
+from .criterion import SetCriterion
+
+
+def build_fcos_pss(cfg, is_val=False):  # Build the FcosPSS detector and, when is_val is truthy, its SetCriterion
+    # ------------ build object detector ------------
+    ## PSS-FCOS
+    model = FcosPSS(cfg          = cfg,
+                    conf_thresh  = cfg.train_conf_thresh if is_val else cfg.test_conf_thresh,  # train_* values are picked when is_val is truthy
+                    topk_results = cfg.train_topk        if is_val else cfg.test_topk,
+                    )
+    criterion = None
+    if is_val:  # NOTE(review): presumably is_val marks a training run that also validates - confirm with callers
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+
+    return model, criterion
+    

+ 169 - 0
odlab/models/detectors/fcos_pss/criterion.py

@@ -0,0 +1,169 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils.misc import sigmoid_focal_loss
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import AlignedOTAMatcher
+
+
+class SetCriterion(nn.Module):
+    def __init__(self, cfg):  # Read loss hyper-parameters and loss weights from cfg and build the label-assignment matcher
+        super().__init__()
+        # ------------- Basic parameters -------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes  # forward() also treats this value as the background label id
+        # ------------- Focal loss -------------
+        self.alpha = cfg.focal_loss_alpha
+        self.gamma = cfg.focal_loss_gamma
+        # ------------- Loss weight -------------
+        self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
+                            'loss_reg': cfg.loss_reg_weight,
+                            'loss_pss': cfg.loss_pss_weight}  # multipliers applied to each loss term in forward()
+        # ------------- Matcher -------------
+        self.matcher_cfg = cfg.matcher_hpy
+        self.matcher = AlignedOTAMatcher(cfg.num_classes,
+                                         cfg.matcher_hpy['soft_center_radius'],
+                                         cfg.matcher_hpy['topk_candidates'],
+                                         )
+
+    def loss_labels(self, pred_cls, target, beta=2.0, num_boxes=1.0):
+        # Quality Focal Loss
+        """
+            pred_cls: (torch.Tensor): [N, C] classification logits.
+            target:   (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N,)
+            beta:      exponent of the modulating factor.
+            num_boxes: value the summed loss is divided by.
+        """
+        label, score = target
+        pred_sigmoid = pred_cls.sigmoid()
+        scale_factor = pred_sigmoid
+        zerolabel = scale_factor.new_zeros(pred_cls.shape)
+
+        ce_loss = F.binary_cross_entropy_with_logits(
+            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)  # start by treating every entry as a negative (target 0)
+        
+        bg_class_ind = pred_cls.shape[-1]  # labels equal to C (or above) count as background
+        pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
+        if pos.shape[0] > 0:
+            pos_label = label[pos].long()
+
+            scale_factor = score[pos] - pred_sigmoid[pos, pos_label]  # gap between quality target and predicted probability
+
+            ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
+                pred_cls[pos, pos_label], score[pos],
+                reduction='none') * scale_factor.abs().pow(beta)  # overwrite the positive entries with BCE against the soft score
+
+        return ce_loss.sum() / num_boxes
+    
+    def loss_bboxes(self, pred_box, gt_box, num_boxes=1.0, box_weight=None):  # Box loss (1 - GIoU), optionally weighted per box, summed and divided by num_boxes
+        ious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou')
+        loss_box = 1.0 - ious
+
+        if box_weight is not None:
+            loss_box = loss_box.squeeze(-1) * box_weight  # squeeze only runs on the weighted path
+
+        return loss_box.sum() / num_boxes
+
+    def loss_pss(self, pred_pss, target, num_boxes=1.0):
+        loss_pss = sigmoid_focal_loss(pred_pss, target, alpha=0.25, gamma=2.0)
+        return loss_pss.sum() / num_boxes
+
+    def forward(self, outputs, targets):
+        """
+            Assign targets per image, then compute the PSS, classification and box losses.
+
+            outputs['pred_cls']: (List[Tensor]) one [B, M_i, C] tensor per level
+            outputs['pred_pss']: (List[Tensor]) one [B, M_i, 1] tensor per level
+            outputs['pred_box']: (List[Tensor]) one [B, M_i, 4] tensor per level
+            outputs['anchors']: (List[Tensor]) one [M_i, 2] tensor per level
+            outputs['mask']: (List[Tensor]) one [B, M_i] bool tensor per level
+            outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+
+            Returns a dict with loss_cls, loss_reg, loss_pss and their weighted sum under 'losses'.
+        """
+        # -------------------- Pre-process --------------------
+        bs          = outputs['pred_cls'][0].shape[0]
+        device      = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        anchors     = outputs['anchors']
+        # Concatenate levels: List([B, M_i, C]) -> [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        pss_preds = torch.cat(outputs['pred_pss'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+        masks = ~torch.cat(outputs['mask'], dim=1).view(-1)  # inverted mask, flattened to [BM,]; NOTE(review): needs a non-empty mask list - confirm callers always pass src_mask
+
+        # -------------------- Label Assignment --------------------
+        cls_targets = []
+        pss_targets = []
+        box_targets = []
+        assign_metrics = []
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
+            # drop targets whose shorter side is below 8
+            tgt_boxes_wh = tgt_bboxes[..., 2:] - tgt_bboxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= 8)
+            tgt_bboxes = tgt_bboxes[keep]
+            tgt_labels = tgt_labels[keep]
+            # label assignment on detached predictions
+            assigned_result = self.matcher(fpn_strides=fpn_strides,
+                                           anchors=anchors,
+                                           pred_cls=cls_preds[batch_idx].detach(),
+                                           pred_box=box_preds[batch_idx].detach(),
+                                           gt_labels=tgt_labels,
+                                           gt_bboxes=tgt_bboxes
+                                           )
+            cls_targets.append(assigned_result['assigned_labels'])
+            pss_targets.append(assigned_result['assigned_pss'])
+            box_targets.append(assigned_result['assigned_bboxes'])
+            assign_metrics.append(assigned_result['assign_metrics'])
+
+        # List of per-image tensors -> one tensor over all B*M locations
+        cls_targets = torch.cat(cls_targets, dim=0)  # [BM,]
+        pss_targets = torch.cat(pss_targets, dim=0)  # [BM,]
+        box_targets = torch.cat(box_targets, dim=0)  # [BM, 4]
+        assign_metrics = torch.cat(assign_metrics, dim=0)  # [BM,]
+
+        valid_idxs = (cls_targets >= 0) & masks
+        foreground_idxs = (cls_targets >= 0) & (cls_targets != self.num_classes)  # num_classes is the background label
+
+        num_fgs = assign_metrics.sum()  # normalizer is the sum of assigned IoU metrics, not a count
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
+
+        num_targets = pss_targets[valid_idxs].sum()  # number of valid locations with PSS target 1
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_targets)
+        num_targets = torch.clamp(num_targets / get_world_size(), min=1).item()
+
+        # -------------------- Pos-Sample selector loss --------------------
+        pss_preds = pss_preds.view(-1)[valid_idxs]
+        loss_pss = self.loss_pss(pss_preds, pss_targets[valid_idxs], num_targets)
+
+        # -------------------- Classification loss --------------------
+        cls_preds = cls_preds.view(-1, self.num_classes)[valid_idxs]
+        pss_preds = pss_preds.unsqueeze(-1)  # NOTE(review): this result is not read again in this method
+        qfl_targets = (cls_targets[valid_idxs], assign_metrics[valid_idxs])
+        loss_cls = self.loss_labels(cls_preds, qfl_targets, 2.0, num_fgs)
+
+        # -------------------- Regression loss --------------------
+        box_preds_pos = box_preds.view(-1, 4)[foreground_idxs]
+        box_targets_pos = box_targets[foreground_idxs]
+        box_weight = assign_metrics[foreground_idxs]
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, num_fgs, box_weight)
+
+        total_loss = loss_cls * self.weight_dict["loss_cls"] + \
+                     loss_box * self.weight_dict["loss_reg"] + \
+                     loss_pss * self.weight_dict["loss_pss"]
+        loss_dict = dict(
+                loss_cls = loss_cls,
+                loss_reg = loss_box,
+                loss_pss = loss_pss,
+                losses   = total_loss,
+        )
+
+        return loss_dict
+    
+
+if __name__ == "__main__":
+    pass  # placeholder: running this module directly does nothing

+ 130 - 0
odlab/models/detectors/fcos_pss/fcos.py

@@ -0,0 +1,130 @@
+import copy
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from ...backbone import build_backbone
+from ...neck     import build_neck
+from ...head     import build_head
+
+
+# --------------------- End-to-End RT-FCOS ---------------------
+class FcosPSS(nn.Module):
+    def __init__(self,  # Assemble backbone, FPN neck and detection head from cfg
+                 cfg,
+                 conf_thresh  :float = 0.05,  # score threshold applied in post_process
+                 topk_results :int   = 1000,  # per-level cap on kept candidates in post_process
+                 ):
+        super(FcosPSS, self).__init__()
+        # ---------------------- Basic Parameters ----------------------
+        self.conf_thresh  = conf_thresh
+        self.num_classes  = cfg.num_classes
+        self.topk_results = topk_results
+
+        # ---------------------- Network Parameters ----------------------
+        ## Backbone
+        self.backbone, pyramid_feats = build_backbone(cfg)
+
+        ## Neck
+        self.backbone_fpn = build_neck(cfg, pyramid_feats, cfg.head_dim)
+
+        ## Heads
+        self.detection_head = build_head(cfg, cfg.head_dim, cfg.head_dim)
+
+    def post_process(self, cls_preds, box_preds, pss_preds):
+        """
+        Decode per-level predictions of the first image into final detections (no NMS).
+
+        Input:
+            cls_preds: List(Tensor) [[B, H x W, C], ...]
+            box_preds: List(Tensor) [[B, H x W, 4], ...]
+            pss_preds: List(Tensor) [[B, H x W, 1], ...]
+        Output:
+            bboxes, scores, labels as numpy arrays; only batch index 0 is used.
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for cls_pred_i, box_pred_i, pss_pred_i in zip(cls_preds, box_preds, pss_preds):
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            pss_pred_i = pss_pred_i[0]
+            
+            # [M, C] -> [MC,]; score = class probability * PSS probability
+            scores_i = (cls_pred_i.sigmoid() * pss_pred_i.sigmoid()).flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk_results, box_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            topk_idxs = topk_idxs[keep_idxs]
+
+            # final scores
+            scores = topk_scores[keep_idxs]
+            # final labels: flat index modulo C gives the class
+            labels = topk_idxs % self.num_classes
+            # final bboxes: flat index floor-divided by C gives the location
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        return bboxes, scores, labels
+
+    def inference(self, src):  # Run backbone, neck and head, then return post-processed detections as a dict
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(src)
+
+        # ---------------- Neck ----------------
+        pyramid_feats = self.backbone_fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        outputs  = self.detection_head(pyramid_feats)  # no mask is passed on this path
+        cls_pred = outputs["pred_cls"]
+        box_pred = outputs["pred_box"]
+        pss_pred = outputs["pred_pss"]
+
+        # Post-process (no NMS)
+        bboxes, scores, labels = self.post_process(cls_pred, box_pred, pss_pred)
+
+        # Normalize bbox: x by input width, y by input height, then clip to [0, 1]
+        bboxes[..., 0::2] /= src.shape[-1]
+        bboxes[..., 1::2] /= src.shape[-2]
+        bboxes = bboxes.clip(0., 1.)
+
+        outputs = {
+            'scores': scores,
+            'labels': labels,
+            'bboxes': bboxes
+        }
+
+        return outputs
+
+    def forward(self, src, src_mask=None):  # Eval mode: return detections from inference(); train mode: return raw head outputs
+        if not self.training:
+            return self.inference(src)
+        else:
+            # ---------------- Backbone ----------------
+            pyramid_feats = self.backbone(src)
+
+            # ---------------- Neck ----------------
+            pyramid_feats = self.backbone_fpn(pyramid_feats)
+
+            # ---------------- Heads ----------------
+            outputs = self.detection_head(pyramid_feats, src_mask)  # src_mask is resized per level inside the head
+
+            return outputs 

+ 188 - 0
odlab/models/detectors/fcos_pss/matcher.py

@@ -0,0 +1,188 @@
+import torch
+import torch.nn.functional as F
+
+from utils.box_ops import box_iou
+
+
+class AlignedOTAMatcher(object):
+    """
+    This code referenced to https://github.com/open-mmlab/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py
+    """
+    def __init__(self,  # Store the matcher hyper-parameters
+                 num_classes,  # also used as the background label for unmatched anchors
+                 soft_center_radius=3.0,  # offset subtracted from the stride-normalized center distance in the prior
+                 topk_candidates=13,  # number of top IoUs summed to derive the dynamic k
+                 ):
+        self.num_classes = num_classes
+        self.soft_center_radius = soft_center_radius
+        self.topk_candidates = topk_candidates
+
+    @torch.no_grad()
+    def __call__(self,  # Assign labels, boxes, IoU metrics and PSS targets to the M anchors of one image
+                 fpn_strides,  # one stride per pyramid level
+                 anchors,  # list with one [M_i, 2] anchor-point tensor per level
+                 pred_cls,  # [M, C] classification logits
+                 pred_box,  # [M, 4] predicted boxes
+                 gt_labels,  # [N,]
+                 gt_bboxes):  # [N, 4]
+        # per-anchor stride: [M,]
+        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
+        # List of [M_i, 2] -> [M, 2]
+        num_gt = len(gt_labels)
+        anchors = torch.cat(anchors, dim=0)
+
+        # no usable gt: every anchor becomes background with zero box, metric and PSS target
+        if num_gt == 0 or gt_bboxes.max().item() == 0.:
+            return {
+                'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape, self.num_classes).long(),
+                'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
+                'assign_metrics':  gt_bboxes.new_full(pred_cls[..., 0].shape, 0),
+                'assigned_pss':    gt_labels.new_full(pred_cls[..., 0].shape, 0).float(),
+            }
+        
+        # get inside points: [N, M]
+        is_in_gt = self.find_inside_points(gt_bboxes, anchors)
+        valid_mask = is_in_gt.sum(dim=0) > 0  # [M,]; True if the anchor lies inside at least one gt box
+
+        # ----------------------- Soft center prior -----------------------
+        gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
+        distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
+                    ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0)  # [N, M]
+        distance = distance * valid_mask.unsqueeze(0)  # distance is zeroed for anchors outside every gt
+        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
+
+        # ----------------------- Regression cost -----------------------
+        pair_wise_ious, _ = box_iou(gt_bboxes, pred_box)  # [N, M]
+        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * 3.0
+
+        # ----------------------- Classification cost -----------------------
+        ## select the predicted scores corresponded to the gt_labels
+        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
+        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()   # [N, M]
+        ## scale factor
+        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
+        ## cls cost: BCE against the IoU as soft target
+        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
+            pairwise_pred_scores, pair_wise_ious,
+            reduction="none") * scale_factor # [N, M]
+            
+        del pairwise_pred_scores
+
+        ## foreground cost matrix; anchors outside every gt get a cost of 1e9
+        cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
+        max_pad_value = torch.ones_like(cost_matrix) * 1e9
+        cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1),   # [N, M]
+                                  cost_matrix, max_pad_value)
+
+        # ----------------------- Dynamic label assignment -----------------------
+        matched_pred_ious, matched_gt_inds, fg_mask_inboxes = self.dynamic_k_matching(
+            cost_matrix, pair_wise_ious, num_gt)
+        matched_pss_inds, fg_mask_top1 = self.topk_one_matching(cost_matrix, pair_wise_ious, num_gt)  # second matching on the same costs for the PSS targets
+        del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss
+
+        # ----------------------- Process assigned labels -----------------------
+        assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
+                                             self.num_classes)  # [M,]
+        assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
+        assigned_labels = assigned_labels.long()  # [M,]
+
+        assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0)        # [M, 4]
+        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds]  # [M, 4]
+
+        assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0) # [M,]
+        assign_metrics[fg_mask_inboxes] = matched_pred_ious            # [M,]
+
+        assigned_pss = gt_labels.new_zeros(pred_cls[..., 0].shape)    # [M,]
+        assigned_pss[fg_mask_top1] = torch.ones_like(gt_labels)[matched_pss_inds].squeeze(-1)  # 1 at anchors chosen by topk_one_matching
+        assigned_pss = assigned_pss.float()
+
+        assigned_dict = dict(
+            assigned_labels = assigned_labels,
+            assigned_bboxes = assigned_bboxes,
+            assign_metrics  = assign_metrics,
+            assigned_pss    = assigned_pss,
+            )
+        
+        return assigned_dict
+
+    def find_inside_points(self, gt_bboxes, anchors):
+        """
+            Return a [N, M] bool mask that is True where an anchor point lies strictly inside a gt box.
+
+            gt_bboxes: Tensor -> [N, 4]
+            anchors:   Tensor -> [M, 2]
+        """
+        num_anchors = anchors.shape[0]
+        num_gt = gt_bboxes.shape[0]
+
+        anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1)           # [N, M, 2]
+        gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1)  # [N, M, 4]
+
+        # offsets from the anchor to the first and second corner of each box
+        lt = anchors_expand - gt_bboxes_expand[..., :2]
+        rb = gt_bboxes_expand[..., 2:] - anchors_expand
+        bbox_deltas = torch.cat([lt, rb], dim=-1)
+
+        is_in_gts = bbox_deltas.min(dim=-1).values > 0  # all four offsets must be positive
+
+        return is_in_gts
+    
+    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):  # Give each gt its k lowest-cost anchors, then resolve anchors claimed by several gts
+        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
+        # select candidate topk ious for dynamic-k calculation
+        candidate_topk = min(self.topk_candidates, pairwise_ious.size(1))
+        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
+        # dynamic k per gt: truncated sum of its top IoUs, at least 1
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+
+        # sorting the batch cost matrix is faster than topk
+        _, sorted_indices = torch.sort(cost_matrix, dim=1)
+        for gt_idx in range(num_gt):
+            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
+            matching_matrix[gt_idx, :][topk_ids] = 1  # the row slice is a view, so this writes into matching_matrix
+
+        del topk_ious, dynamic_ks, topk_ids  # topk_ids only exists when num_gt >= 1
+
+        prior_match_gt_mask = matching_matrix.sum(0) > 1  # anchors matched to more than one gt
+        if prior_match_gt_mask.sum() > 0:
+            cost_min, cost_argmin = torch.min(
+                cost_matrix[:, prior_match_gt_mask], dim=0)
+            matching_matrix[:, prior_match_gt_mask] *= 0
+            matching_matrix[cost_argmin, prior_match_gt_mask] = 1  # keep only the lowest-cost gt for those anchors
+
+        # foreground anchors are those matched to a gt
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+        matched_pred_ious = (matching_matrix *
+                             pairwise_ious).sum(0)[fg_mask_inboxes]
+        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+
+        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes
+
+    def topk_one_matching(self, cost_matrix, pairwise_ious, num_gt):  # Same procedure as dynamic_k_matching but with a single candidate per gt
+        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
+        # take only the single highest IoU per gt
+        candidate_topk = min(1, pairwise_ious.size(1))
+        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
+        # truncating one IoU and clamping at 1 yields k = 1 per gt, assuming IoU <= 1
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+
+        # sorting the batch cost matrix is faster than topk
+        _, sorted_indices = torch.sort(cost_matrix, dim=1)
+        for gt_idx in range(num_gt):
+            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
+            matching_matrix[gt_idx, :][topk_ids] = 1  # the row slice is a view, so this writes into matching_matrix
+
+        del topk_ious, dynamic_ks, topk_ids  # topk_ids only exists when num_gt >= 1
+
+        prior_match_gt_mask = matching_matrix.sum(0) > 1  # anchors picked by more than one gt
+        if prior_match_gt_mask.sum() > 0:
+            cost_min, cost_argmin = torch.min(
+                cost_matrix[:, prior_match_gt_mask], dim=0)
+            matching_matrix[:, prior_match_gt_mask] *= 0
+            matching_matrix[cost_argmin, prior_match_gt_mask] = 1  # keep only the lowest-cost gt for those anchors
+
+        # anchors selected as the one-to-one positives
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+
+        return matched_gt_inds, fg_mask_inboxes
+    

+ 3 - 1
odlab/models/head/__init__.py

@@ -1,5 +1,5 @@
 from .yolof_head     import YolofHead
-from .fcos_head      import FcosHead, FcosRTHead
+from .fcos_head      import FcosHead, FcosRTHead, FcosPSSHead
 
 
 # build head
@@ -11,6 +11,8 @@ def build_head(cfg, in_dim, out_dim):
         model = FcosHead(cfg, in_dim, out_dim)
     elif cfg.head == 'fcos_rt_head':
         model = FcosRTHead(cfg, in_dim, out_dim)
+    elif cfg.head == 'fcos_pss_head':
+        model = FcosPSSHead(cfg, in_dim, out_dim)
     elif cfg.head == 'yolof_head':
         model = YolofHead(cfg, in_dim, out_dim)
 

+ 165 - 0
odlab/models/head/fcos_head.py

@@ -334,3 +334,168 @@ class FcosRTHead(nn.Module):
                    "mask": all_masks}          # List [B, M,]
 
         return outputs 
+
+class FcosPSSHead(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim,):
+        super().__init__()
+        self.fmp_size = None
+        # ------------------ Basic parameters -------------------
+        self.cfg = cfg
+        self.in_dim = in_dim
+        self.stride       = cfg.out_stride
+        self.num_classes  = cfg.num_classes
+        self.num_cls_head = cfg.num_cls_head
+        self.num_reg_head = cfg.num_reg_head
+        self.act_type     = cfg.head_act
+        self.norm_type    = cfg.head_norm
+
+        # ------------------ Model parameters -------------------
+        ## cls head
+        cls_heads = []
+        self.cls_head_dim = out_dim
+        for i in range(self.num_cls_head):
+            if i == 0:
+                cls_heads.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+            else:
+                cls_heads.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+        
+        ## reg head
+        reg_heads = []
+        self.reg_head_dim = out_dim
+        for i in range(self.num_reg_head):
+            if i == 0:
+                reg_heads.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+            else:
+                reg_heads.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=self.act_type, norm_type=self.norm_type)
+                              )
+        self.cls_heads = nn.Sequential(*cls_heads)
+        self.reg_heads = nn.Sequential(*reg_heads)
+
+        ## Pred layers
+        self.cls_pred = nn.Conv2d(self.cls_head_dim, cfg.num_classes, kernel_size=3, padding=1)
+        self.reg_pred = nn.Conv2d(self.reg_head_dim, 4, kernel_size=3, padding=1)
+        self.pss_pred = nn.Sequential(
+            BasicConv(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1, 
+                      act_type=self.act_type, norm_type=self.norm_type),
+            BasicConv(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1, 
+                      act_type=self.act_type, norm_type=self.norm_type),
+            nn.Conv2d(self.cls_head_dim, 1, kernel_size=3, padding=1)
+        )
+                
+        # init bias
+        self._init_layers()
+
+    def _init_layers(self):  # Initialize conv/GroupNorm parameters and set the prior bias of the cls and PSS outputs
+        for module in [self.cls_heads, self.reg_heads, self.cls_pred, self.reg_pred, self.pss_pred]:
+            for layer in module.modules():
+                if isinstance(layer, nn.Conv2d):
+                    torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
+                    if layer.bias is not None:
+                        torch.nn.init.constant_(layer.bias, 0)
+                if isinstance(layer, nn.GroupNorm):
+                    torch.nn.init.constant_(layer.weight, 1)
+                    if layer.bias is not None:
+                        torch.nn.init.constant_(layer.bias, 0)
+        # bias = -log((1 - p) / p) so that sigmoid(bias) equals init_prob
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        torch.nn.init.constant_(self.cls_pred.bias, bias_value)
+        torch.nn.init.constant_(self.pss_pred[-1].bias, bias_value)  # last layer of pss_pred is the 1-channel conv
+        
+    def get_anchors(self, level, fmp_size):
+        """
+            Return the [HW, 2] (x, y) cell-center points of one level, scaled by that level's stride.
+
+            level:    (int) index into self.stride
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]; + 0.5 moves each point to the cell center
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
+        anchors *= self.stride[level]
+
+        return anchors
+        
+    def decode_boxes(self, pred_deltas, anchors, stride):
+        """
+            Turn regression outputs into corner-format boxes (x1, y1, x2, y2).
+
+            pred_deltas: (Tensor) [B, M, 4] or [M, 4] (dx, dy, dw, dh)
+            anchors:     (Tensor) [1, M, 2] or [M, 2]
+            stride:      scale applied to both the center offset and the size
+        """
+        pred_cxcy = anchors + pred_deltas[..., :2] * stride  # center = anchor + offset * stride
+        pred_bwbh = pred_deltas[..., 2:].exp() * stride  # size = exp(delta) * stride
+
+        pred_x1y1 = pred_cxcy - 0.5 * pred_bwbh
+        pred_x2y2 = pred_cxcy + 0.5 * pred_bwbh
+
+        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return pred_box
+    
+    def forward(self, pyramid_feats, mask=None):  # Run the head on every pyramid level and collect per-level predictions in lists
+        all_masks = []
+        all_anchors = []
+        all_cls_preds = []
+        all_pss_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        for level, feat in enumerate(pyramid_feats):
+            # ------------------- Decoupled head -------------------
+            cls_feat = self.cls_heads(feat)
+            reg_feat = self.reg_heads(feat)
+
+            # ------------------- Generate anchor points -------------------
+            B, _, H, W = cls_feat.size()
+            fmp_size = [H, W]
+            anchors = self.get_anchors(level, fmp_size)   # [M, 2]
+            anchors = anchors.to(cls_feat.device)
+
+            # ------------------- Predict -------------------
+            cls_pred = self.cls_pred(cls_feat)
+            reg_pred = self.reg_pred(reg_feat)
+            pss_pred = self.pss_pred(reg_feat.detach())  # detach: the PSS branch sends no gradient into the reg tower
+
+            # ------------------- Process preds -------------------
+            ## [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+            pss_pred = pss_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
+            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+            box_pred = self.decode_boxes(reg_pred, anchors, self.stride[level])
+            ## Adjust mask
+            if mask is not None:
+                # resize the mask to this level's feature-map size: [B, H, W]
+                mask_i = torch.nn.functional.interpolate(mask[None].float(), size=[H, W]).bool()[0]
+                # [B, H, W] -> [B, M]
+                mask_i = mask_i.flatten(1)     
+                all_masks.append(mask_i)
+                
+            all_anchors.append(anchors)
+            all_cls_preds.append(cls_pred)
+            all_pss_preds.append(pss_pred)
+            all_reg_preds.append(reg_pred)
+            all_box_preds.append(box_pred)
+
+        outputs = {"pred_cls": all_cls_preds,  # List [B, M, C]
+                   "pred_pss": all_pss_preds,  # List [B, M, 1]
+                   "pred_reg": all_reg_preds,  # List [B, M, 4]
+                   "pred_box": all_box_preds,  # List [B, M, 4]
+                   "anchors":  all_anchors,    # List [M, 2]
+                   "mask":     all_masks,      # List [B, M,]; stays empty when mask is None
+                   "strides":  self.stride,
+                   }
+
+        return outputs
+