
add e2e fcos

yjh0410 committed 1 year ago
commit 81e06645b8

+ 66 - 0
odlab/config/fcos_config.py

@@ -16,6 +16,9 @@ def build_fcos_config(args):
     elif args.model == 'fcos_rt_r50_3x':
         return FcosRT_R50_3x_Config()
     
+    elif args.model == 'fcos_e2e_r18_3x':
+        return FcosE2E_R18_3x_Config()
+
     else:
         raise NotImplementedError("No config for model: {}".format(args.model))
 
@@ -219,3 +222,66 @@ class FcosRT_R50_3x_Config(FcosRT_R18_3x_Config):
         super().__init__()
         # --------- Backbone ---------
         self.backbone = "resnet50"
+
+# --------------- E2E-FCOS & 3x scheduler ---------------
+class FcosE2E_R18_3x_Config(FcosBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # --------- Backbone ---------
+        self.backbone = "resnet18"
+        self.max_stride = 32
+        self.out_stride = [8, 16, 32]
+
+        # --------- Neck ---------
+        self.neck = 'basic_fpn'
+        self.fpn_p6_feat = False
+        self.fpn_p7_feat = False
+        self.fpn_p6_from_c5  = False
+
+        # --------- Head ---------
+        self.head = 'fcos_rt_head'
+        self.head_dim = 256
+        self.num_cls_head = 4
+        self.num_reg_head = 4
+        self.head_act     = 'relu'
+        self.head_norm    = 'GN'
+
+        # --------- Post-process ---------
+        self.train_topk = 100
+        self.train_conf_thresh = 0.05
+        self.test_topk = 100
+        self.test_conf_thresh = 0.4
+
+        # --------- Label Assignment ---------
+        self.matcher = 'simota'
+        self.matcher_hpy = {'soft_center_radius': 3.0,
+                            'topk_candidates': 13}
+
+        # --------- Loss weight ---------
+        self.focal_loss_alpha = 0.25
+        self.focal_loss_gamma = 2.0
+        self.loss_cls_weight  = 1.0
+        self.loss_reg_weight  = 2.0
+
+        # --------- Train epoch ---------
+        self.max_epoch = 36         # 3x
+        self.lr_epoch  = [24, 33]   # 3x
+
+        # --------- Data process ---------
+        ## input size
+        self.train_min_size = [256, 288, 320, 352, 384, 416, 448, 480, 512, 544, 576, 608]   # short edge of image
+        self.train_max_size = 900
+        self.test_min_size  = [512]
+        self.test_max_size  = 736
+        ## Pixel mean & std
+        self.pixel_mean = [0.485, 0.456, 0.406]
+        self.pixel_std  = [0.229, 0.224, 0.225]
+        ## Transforms
+        self.box_format = 'xyxy'
+        self.normalize_coords = False
+        self.detr_style = False
+        self.trans_config = [
+            {'name': 'RandomHFlip'},
+            {'name': 'RandomResize'},
+        ]
+

+ 7 - 3
odlab/models/detectors/__init__.py

@@ -1,9 +1,10 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import torch
 
-from .fcos.build  import build_fcos, build_fcos_rt
-from .yolof.build import build_yolof
-from .detr.build  import build_detr
+from .fcos.build     import build_fcos, build_fcos_rt
+from .fcos_e2e.build import build_fcos_e2e
+from .yolof.build    import build_yolof
+from .detr.build     import build_detr
 
 
 def build_model(args, cfg, is_val=False):
@@ -11,6 +12,9 @@ def build_model(args, cfg, is_val=False):
     ## RT-FCOS    
     if   'fcos_rt' in args.model:
         model, criterion = build_fcos_rt(cfg, is_val)
+    ## E2E-FCOS
+    elif 'fcos_e2e' in args.model:
+        model, criterion = build_fcos_e2e(cfg, is_val)
     ## FCOS    
     elif 'fcos' in args.model:
         model, criterion = build_fcos(cfg, is_val)

+ 20 - 0
odlab/models/detectors/fcos_e2e/README.md

@@ -0,0 +1,20 @@
+# Empirical research on End-to-End FCOS
+Inspired by YOLOv10, I recently conducted an empirical study on FCOS to evaluate the **End-to-End detection** paradigm.
+
+## Experiments
+
+- COCO
+
+Remarkably, the three FCOS variants run at almost the same FPS!
+
+| Model                | Scale      | FPS<sup>FP32<br>RTX 4060 | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | Weight | Logs |
+|----------------------|------------|--------------------------|------------------------|-------------------|--------|------|
+| FCOS_RT_R18_3x       |  512,736   |           56             |          35.8          |        53.3       | [ckpt](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/fcos_rt_r18_3x_coco.pth) | [log](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/FCOS-RT-R18-3x.txt) |
+| FCOS_RT_R18_3x (O2O) |  512,736   |           56             |          30.9          |        48.8       | [ckpt](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/fcos_rt_r18_3x_top1_coco.pth) | [log](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/FCOS-RT-R18-3x-COCO-top1.txt) |
+| FCOS_E2E_R18_3x      |  512,736   |           56             |          34.1          |        50.6       | [ckpt](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/fcos_e2e_r18_3x_coco.pth) | [log](https://github.com/yjh0410/E2E_FCOS/releases/download/fcos_weight/FCOS-E2E-R18-3x-COCO.txt) |
+
+For **FCOS_RT_R18_3x**, we train `FCOS-RT-R18-3x` with only the one-to-many assigner and evaluate it with NMS.
+
+For **FCOS_RT_R18_3x (O2O)**, we train `FCOS-RT-R18-3x` with only the one-to-one assigner and evaluate it without NMS.
+
+For **FCOS_E2E_R18_3x**, we deploy two parallel detection heads: one uses the one-to-many assigner (the o2m head) and the other uses the one-to-one assigner (the o2o head). To avoid conflicts between the two heads' gradients, we truncate the gradients flowing from the o2o head into the backbone and neck, so that only the o2m head's gradients update them; this is consistent with the practice of YOLOv10 (see the sketch below). For evaluation, we remove the o2m head and use only the o2o head, without NMS.
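+
+A minimal, self-contained sketch of this stop-gradient scheme (toy `nn.Conv2d` modules stand in for the real backbone, FPN, and detection heads):
+
+```python
+import torch
+import torch.nn as nn
+
+# Toy stand-ins for the backbone + neck and the two parallel heads.
+backbone_and_neck = nn.Conv2d(3, 256, kernel_size=3, padding=1)
+o2m_head = nn.Conv2d(256, 80, kernel_size=1)
+o2o_head = nn.Conv2d(256, 80, kernel_size=1)
+
+x = torch.randn(2, 3, 64, 64)
+feats = backbone_and_neck(x)
+
+# o2m head: gradients flow back into the backbone and neck.
+out_o2m = o2m_head(feats)
+
+# o2o head: detach the features, so its gradients stop at the head itself.
+out_o2o = o2o_head(feats.detach())
+
+loss = out_o2m.mean() + out_o2o.mean()  # placeholder for the two criterion terms
+loss.backward()
+# backbone_and_neck.weight.grad now carries only the o2m head's contribution.
+```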

+ 18 - 0
odlab/models/detectors/fcos_e2e/build.py

@@ -0,0 +1,18 @@
+from .fcos import FcosE2E
+from .criterion import SetCriterion
+
+
+def build_fcos_e2e(cfg, is_val=False):
+    # ------------ build object detector ------------
+    ## E2E-FCOS
+    model = FcosE2E(cfg          = cfg,
+                    conf_thresh  = cfg.train_conf_thresh if is_val else cfg.test_conf_thresh,
+                    topk_results = cfg.train_topk        if is_val else cfg.test_topk,
+                    )
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+
+    return model, criterion

+ 171 - 0
odlab/models/detectors/fcos_e2e/criterion.py

@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import AlignedOTAMatcher
+
+
+class SetCriterion(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        # ------------- Basic parameters -------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        # ------------- Focal loss -------------
+        self.alpha = cfg.focal_loss_alpha
+        self.gamma = cfg.focal_loss_gamma
+        # ------------- Loss weight -------------
+        self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
+                            'loss_reg': cfg.loss_reg_weight}
+        # ------------- Matcher & Loss weight -------------
+        self.matcher_cfg = cfg.matcher_hpy
+        self.matcher = AlignedOTAMatcher(cfg.num_classes,
+                                         cfg.matcher_hpy['soft_center_radius'],
+                                         cfg.matcher_hpy['topk_candidates'],
+                                         )
+
+    def loss_labels(self, pred_cls, target, beta=2.0, num_boxes=1.0):
+        """
+        Quality Focal Loss (QFL).
+            pred_cls: (torch.Tensor) [N, C]
+            target:   (tuple) (labels -> [N,], scores -> [N,])
+        """
+        label, score = target
+        pred_sigmoid = pred_cls.sigmoid()
+        scale_factor = pred_sigmoid
+        zerolabel = scale_factor.new_zeros(pred_cls.shape)
+
+        ce_loss = F.binary_cross_entropy_with_logits(
+            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
+        
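+        # All entries start as negatives weighted by sigmoid(pred)^beta; positive
+        # positions below are re-weighted by |score - sigmoid(pred)|^beta.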
+        bg_class_ind = pred_cls.shape[-1]
+        pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
+        if pos.shape[0] > 0:
+            pos_label = label[pos].long()
+
+            scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
+
+            ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
+                pred_cls[pos, pos_label], score[pos],
+                reduction='none') * scale_factor.abs().pow(beta)
+
+        return ce_loss.sum() / num_boxes
+    
+    def loss_bboxes(self, pred_box, gt_box, num_boxes=1.0, box_weight=None):
+        ious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou')
+        loss_box = 1.0 - ious
+
+        if box_weight is not None:
+            loss_box = loss_box.squeeze(-1) * box_weight
+
+        return loss_box.sum() / num_boxes
+
+    def compute_loss(self, outputs, targets):
+        """
+            outputs['pred_cls']: (Tensor) [B, M, C]
+            outputs['pred_reg']: (Tensor) [B, M, 4]
+            outputs['pred_box']: (Tensor) [B, M, 4]
+            outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        # -------------------- Pre-process --------------------
+        bs          = outputs['pred_cls'][0].shape[0]
+        device      = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        anchors     = outputs['anchors']
+        # preds: [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+        masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
+
+        # -------------------- Label Assignment --------------------
+        cls_targets = []
+        box_targets = []
+        assign_metrics = []
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
+            # refine target
+            tgt_boxes_wh = tgt_bboxes[..., 2:] - tgt_bboxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= 8)
+            tgt_bboxes = tgt_bboxes[keep]
+            tgt_labels = tgt_labels[keep]
+            # label assignment
+            assigned_result = self.matcher(fpn_strides=fpn_strides,
+                                           anchors=anchors,
+                                           pred_cls=cls_preds[batch_idx].detach(),
+                                           pred_box=box_preds[batch_idx].detach(),
+                                           gt_labels=tgt_labels,
+                                           gt_bboxes=tgt_bboxes
+                                           )
+            cls_targets.append(assigned_result['assigned_labels'])
+            box_targets.append(assigned_result['assigned_bboxes'])
+            assign_metrics.append(assigned_result['assign_metrics'])
+
+        # List[B, M, C] -> Tensor[BM, C]
+        cls_targets = torch.cat(cls_targets, dim=0)
+        box_targets = torch.cat(box_targets, dim=0)
+        assign_metrics = torch.cat(assign_metrics, dim=0)
+
+        valid_idxs = (cls_targets >= 0) & masks
+        foreground_idxs = (cls_targets >= 0) & (cls_targets != self.num_classes)
+        num_fgs = assign_metrics.sum()
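+        # Normalizer: summed assignment quality (matched IoUs) rather than the raw
+        # positive count; averaged across distributed processes below.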
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
+
+        # -------------------- Classification loss --------------------
+        cls_preds = cls_preds.view(-1, self.num_classes)[valid_idxs]
+        qfl_targets = (cls_targets[valid_idxs], assign_metrics[valid_idxs])
+        loss_labels = self.loss_labels(cls_preds, qfl_targets, 2.0, num_fgs)
+
+        # -------------------- Regression loss --------------------
+        box_preds_pos = box_preds.view(-1, 4)[foreground_idxs]
+        box_targets_pos = box_targets[foreground_idxs]
+        box_weight = assign_metrics[foreground_idxs]
+        loss_bboxes = self.loss_bboxes(box_preds_pos, box_targets_pos, num_fgs, box_weight)
+
+        total_loss = loss_labels * self.weight_dict["loss_cls"] + \
+                     loss_bboxes * self.weight_dict["loss_reg"]
+        loss_dict = dict(
+                loss_cls = loss_labels,
+                loss_reg = loss_bboxes,
+                loss     = total_loss,
+        )
+
+        return loss_dict
+    
+    def forward(self, outputs, targets):
+        """
+            outputs['pred_cls']: (Tensor) [B, M, C]
+            outputs['pred_reg']: (Tensor) [B, M, 4]
+            outputs['pred_box']: (Tensor) [B, M, 4]
+            outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
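+        # One matcher serves both heads: the configured top-k for the o2m head,
+        # and top-k = 1 to emulate one-to-one assignment for the o2o head.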
+        self.matcher.topk_candidates = self.cfg.matcher_hpy['topk_candidates']
+        o2m_loss_dict = self.compute_loss(outputs["outputs_o2m"], targets)
+        self.matcher.topk_candidates = 1
+        o2o_loss_dict = self.compute_loss(outputs["outputs_o2o"], targets)
+
+        loss_dict = {}
+        loss_dict["loss"] = o2o_loss_dict["loss"] + o2m_loss_dict["loss"]
+        for k in o2m_loss_dict:
+            loss_dict['o2m_' + k] = o2m_loss_dict[k]
+        for k in o2o_loss_dict:
+            loss_dict['o2o_' + k] = o2o_loss_dict[k]
+
+        return loss_dict
+    
+
+if __name__ == "__main__":
+    pass

+ 137 - 0
odlab/models/detectors/fcos_e2e/fcos.py

@@ -0,0 +1,137 @@
+import copy
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from ...backbone import build_backbone
+from ...neck import build_neck
+from ...head import build_head
+
+
+# --------------------- End-to-End RT-FCOS ---------------------
+class FcosE2E(nn.Module):
+    def __init__(self, 
+                 cfg,
+                 conf_thresh  :float = 0.05,
+                 topk_results :int   = 1000,
+                 ):
+        super(FcosE2E, self).__init__()
+        # ---------------------- Basic Parameters ----------------------
+        self.conf_thresh  = conf_thresh
+        self.num_classes  = cfg.num_classes
+        self.topk_results = topk_results
+
+        # ---------------------- Network Parameters ----------------------
+        ## Backbone
+        self.backbone, pyramid_feats = build_backbone(cfg)
+
+        ## Neck
+        self.backbone_fpn = build_neck(cfg, pyramid_feats, cfg.head_dim)
+
+        ## Heads (one-to-many)
+        self.detection_head_o2m = build_head(cfg, cfg.head_dim, cfg.head_dim)
+
+        ## Heads (one-to-one)
+        self.detection_head_o2o = copy.deepcopy(self.detection_head_o2m)
+
+    def post_process(self, cls_preds, box_preds):
+        """
+        Input:
+            cls_preds: List(Tensor) [[B, H x W, C], ...]
+            box_preds: List(Tensor) [[B, H x W, 4], ...]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            # batch size is assumed to be 1 at inference time
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            
+            # (H x W x C,)
+            scores_i = cls_pred_i.sigmoid().flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk_results, box_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            topk_idxs = topk_idxs[keep_idxs]
+
+            # final scores
+            scores = topk_scores[keep_idxs]
+            # final labels
+            labels = topk_idxs % self.num_classes
+            # final bboxes
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        return bboxes, scores, labels
+
+    def inference_o2o(self, src):
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(src)
+
+        # ---------------- Neck ----------------
+        pyramid_feats = self.backbone_fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        outputs = self.detection_head_o2o(pyramid_feats)
+        cls_pred = outputs["pred_cls"]
+        box_pred = outputs["pred_box"]
+
+        # PostProcess (no NMS)
+        bboxes, scores, labels = self.post_process(cls_pred, box_pred)
+
+        # Normalize bbox
+        bboxes[..., 0::2] /= src.shape[-1]
+        bboxes[..., 1::2] /= src.shape[-2]
+        bboxes = bboxes.clip(0., 1.)
+
+        outputs = {
+            'scores': scores,
+            'labels': labels,
+            'bboxes': bboxes
+        }
+
+        return outputs
+
+    def forward(self, src, src_mask=None):
+        if not self.training:
+            return self.inference_o2o(src)
+        else:
+            # ---------------- Backbone ----------------
+            pyramid_feats = self.backbone(src)
+
+            # ---------------- Neck ----------------
+            pyramid_feats = self.backbone_fpn(pyramid_feats)
+
+            # ---------------- Heads ----------------
+            outputs = {}
+            ## One-to-many detection
+            outputs_o2m = self.detection_head_o2m(pyramid_feats, src_mask)
+            outputs["outputs_o2m"] = outputs_o2m
+            ## One-to-one  detection
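+            # Detach the FPN features so the o2o head's gradients stop here and
+            # never reach the backbone or neck (YOLOv10-style stop-gradient).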
+            pyramid_feats_detach = [feat.detach() for feat in pyramid_feats]
+            outputs_o2o = self.detection_head_o2o(pyramid_feats_detach, src_mask)
+            outputs["outputs_o2o"] = outputs_o2o
+
+            return outputs 

+ 148 - 0
odlab/models/detectors/fcos_e2e/matcher.py

@@ -0,0 +1,148 @@
+import torch
+import torch.nn.functional as F
+
+from utils.box_ops import box_iou
+
+
+class AlignedOTAMatcher(object):
+    def __init__(self,
+                 num_classes,
+                 soft_center_radius=3.0,
+                 topk_candidates=13,
+                 ):
+        self.num_classes = num_classes
+        self.soft_center_radius = soft_center_radius
+        self.topk_candidates = topk_candidates
+
+    @torch.no_grad()
+    def __call__(self, 
+                 fpn_strides, 
+                 anchors, 
+                 pred_cls, 
+                 pred_box,
+                 gt_labels,
+                 gt_bboxes):
+        # [M,]
+        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
+        # List[F, M, 2] -> [M, 2]
+        num_gt = len(gt_labels)
+        anchors = torch.cat(anchors, dim=0)
+
+        # check gt
+        if num_gt == 0 or gt_bboxes.max().item() == 0.:
+            return {
+                'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape, self.num_classes).long(),
+                'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
+                'assign_metrics':  gt_bboxes.new_full(pred_cls[..., 0].shape, 0),
+            }
+        
+        # get inside points: [N, M]
+        is_in_gt = self.find_inside_points(gt_bboxes, anchors)
+        valid_mask = is_in_gt.sum(dim=0) > 0  # [M,]
+
+        # ----------------------- Soft center prior -----------------------
+        gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
+        distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
+                    ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0)  # [N, M]
+        distance = distance * valid_mask.unsqueeze(0)
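+        # The prior grows as 10^(normalized center distance - radius): negligible
+        # near a GT center, dominant far from it.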
+        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
+
+        # ----------------------- Regression cost -----------------------
+        pair_wise_ious, _ = box_iou(gt_bboxes, pred_box)  # [N, M]
+        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * 3.0
+
+        # ----------------------- Classification cost -----------------------
+        ## select the predicted scores corresponded to the gt_labels
+        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
+        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()   # [N, M]
+        ## scale factor
+        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
+        ## cls cost
+        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
+            pairwise_pred_scores, pair_wise_ious,
+            reduction="none") * scale_factor # [N, M]
+            
+        del pairwise_pred_scores
+
+        ## foreground cost matrix
+        cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
+        max_pad_value = torch.ones_like(cost_matrix) * 1e9
+        cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1),   # [N, M]
+                                  cost_matrix, max_pad_value)
+
+        # ----------------------- Dynamic label assignment -----------------------
+        matched_pred_ious, matched_gt_inds, fg_mask_inboxes = self.dynamic_k_matching(
+            cost_matrix, pair_wise_ious, num_gt)
+        del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss
+
+        # ----------------------- Process assigned labels -----------------------
+        assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
+                                             self.num_classes)  # [M,]
+        assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
+        assigned_labels = assigned_labels.long()  # [M,]
+
+        assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0)        # [M, 4]
+        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds]  # [M, 4]
+
+        assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0) # [M,]
+        assign_metrics[fg_mask_inboxes] = matched_pred_ious            # [M,]
+
+        assigned_dict = dict(
+            assigned_labels=assigned_labels,
+            assigned_bboxes=assigned_bboxes,
+            assign_metrics=assign_metrics
+            )
+        
+        return assigned_dict
+
+    def find_inside_points(self, gt_bboxes, anchors):
+        """
+            gt_bboxes: Tensor -> [N, 4]
+            anchors:   Tensor -> [M, 2]
+        """
+        num_anchors = anchors.shape[0]
+        num_gt = gt_bboxes.shape[0]
+
+        anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1)           # [N, M, 2]
+        gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1)  # [N, M, 4]
+
+        # offset
+        lt = anchors_expand - gt_bboxes_expand[..., :2]
+        rb = gt_bboxes_expand[..., 2:] - anchors_expand
+        bbox_deltas = torch.cat([lt, rb], dim=-1)
+
+        is_in_gts = bbox_deltas.min(dim=-1).values > 0
+
+        return is_in_gts
+    
+    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
+        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
+        # select candidate topk ious for dynamic-k calculation
+        candidate_topk = min(self.topk_candidates, pairwise_ious.size(1))
+        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
+        # calculate dynamic k for each gt
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+
+        # sorting the batch cost matrix is faster than topk
+        _, sorted_indices = torch.sort(cost_matrix, dim=1)
+        for gt_idx in range(num_gt):
+            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
+            matching_matrix[gt_idx, :][topk_ids] = 1
+
+        del topk_ious, dynamic_ks, topk_ids
+
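+        # If an anchor is matched to more than one GT, keep only the match with
+        # the lowest cost.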
+        prior_match_gt_mask = matching_matrix.sum(0) > 1
+        if prior_match_gt_mask.sum() > 0:
+            cost_min, cost_argmin = torch.min(
+                cost_matrix[:, prior_match_gt_mask], dim=0)
+            matching_matrix[:, prior_match_gt_mask] *= 0
+            matching_matrix[cost_argmin, prior_match_gt_mask] = 1
+
+        # get foreground mask inside box and center prior
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+        matched_pred_ious = (matching_matrix *
+                             pairwise_ious).sum(0)[fg_mask_inboxes]
+        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+
+        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes