
modify loss_qfl

yjh0410 2 years ago
parent
commit
e4ed9a9b48

+ 42 - 105
config/model_config/rtcdet_config.py

@@ -46,23 +46,14 @@ rtcdet_cfg = {
         # ---------------- Assignment config ----------------
         ## Matcher
         'matcher': "simota",
-        'matcher_hpy': {"simota": {'center_sampling_radius': 2.5,
-                                   'topk_candidate': 10},
-                        "aligned_simota": {'soft_center_radius': 3.0,
-                                           'topk_candicate': 10,
-                                           'iou_weight': 3.0},
-                                           },
+        'matcher_hpy': {'center_sampling_radius': 2.5,
+                        'topk_candidate': 10},
         # ---------------- Loss config ----------------
-        ## Loss weight
-        'ema_update': False,
+        'cls_loss': 'bce',
+        'loss_cls_weight': 1.0,
+        'loss_dfl_weight': 1.0,
+        'loss_box_weight': 5.0,
         'loss_box_aux': True,
-        'loss_weights': {"simota": {'loss_cls_weight': 1.0,
-                                    'loss_dfl_weight': 1.0,
-                                    'loss_box_weight': 5.0},
-                         "aligned_simota": {'loss_cls_weight': 1.0,
-                                            'loss_dfl_weight': 1.0,
-                                            'loss_box_weight': 2.0}
-                                            },
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
     },
@@ -111,23 +102,14 @@ rtcdet_cfg = {
         # ---------------- Assignment config ----------------
         ## Matcher
         'matcher': "simota",
-        'matcher_hpy': {"simota": {'center_sampling_radius': 2.5,
-                                   'topk_candidate': 10},
-                        "aligned_simota": {'soft_center_radius': 3.0,
-                                           'topk_candicate': 10,
-                                           'iou_weight': 3.0},
-                                           },
+        'matcher_hpy': {'center_sampling_radius': 2.5,
+                        'topk_candidate': 10},
         # ---------------- Loss config ----------------
-        ## Loss weight
-        'ema_update': False,
+        'cls_loss': 'qfl',
+        'loss_cls_weight': 1.0,
+        'loss_dfl_weight': 1.0,
+        'loss_box_weight': 5.0,
         'loss_box_aux': True,
-        'loss_weights': {"simota": {'loss_cls_weight': 1.0,
-                                    'loss_dfl_weight': 1.0,
-                                    'loss_box_weight': 5.0},
-                         "aligned_simota": {'loss_cls_weight': 1.0,
-                                            'loss_dfl_weight': 1.0,
-                                            'loss_box_weight': 2.0}
-                                            },
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
     },
@@ -176,23 +158,14 @@ rtcdet_cfg = {
         # ---------------- Assignment config ----------------
         ## Matcher
         'matcher': "simota",
-        'matcher_hpy': {"simota": {'center_sampling_radius': 2.5,
-                                   'topk_candidate': 10},
-                        "aligned_simota": {'soft_center_radius': 3.0,
-                                           'topk_candicate': 10,
-                                           'iou_weight': 3.0},
-                                           },
+        'matcher_hpy': {'center_sampling_radius': 2.5,
+                        'topk_candidate': 10},
         # ---------------- Loss config ----------------
-        ## Loss weight
-        'ema_update': False,
+        'cls_loss': 'bce',
+        'loss_cls_weight': 1.0,
+        'loss_dfl_weight': 1.0,
+        'loss_box_weight': 5.0,
         'loss_box_aux': True,
-        'loss_weights': {"simota": {'loss_cls_weight': 1.0,
-                                    'loss_dfl_weight': 1.0,
-                                    'loss_box_weight': 5.0},
-                         "aligned_simota": {'loss_cls_weight': 1.0,
-                                            'loss_dfl_weight': 1.0,
-                                            'loss_box_weight': 2.0}
-                                            },
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
     },
@@ -241,23 +214,14 @@ rtcdet_cfg = {
         # ---------------- Assignment config ----------------
         ## Matcher
         'matcher': "simota",
-        'matcher_hpy': {"simota": {'center_sampling_radius': 2.5,
-                                   'topk_candidate': 10},
-                        "aligned_simota": {'soft_center_radius': 3.0,
-                                           'topk_candicate': 10,
-                                           'iou_weight': 3.0},
-                                           },
+        'matcher_hpy': {'center_sampling_radius': 2.5,
+                        'topk_candidate': 10},
         # ---------------- Loss config ----------------
-        ## Loss weight
-        'ema_update': False,
+        'cls_loss': 'bce',
+        'loss_cls_weight': 1.0,
+        'loss_dfl_weight': 1.0,
+        'loss_box_weight': 5.0,
         'loss_box_aux': True,
-        'loss_weights': {"simota": {'loss_cls_weight': 1.0,
-                                    'loss_dfl_weight': 1.0,
-                                    'loss_box_weight': 5.0},
-                         "aligned_simota": {'loss_cls_weight': 1.0,
-                                            'loss_dfl_weight': 1.0,
-                                            'loss_box_weight': 2.0}
-                                            },
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
     },
@@ -306,23 +270,14 @@ rtcdet_cfg = {
         # ---------------- Assignment config ----------------
         ## Matcher
         'matcher': "simota",
-        'matcher_hpy': {"simota": {'center_sampling_radius': 2.5,
-                                   'topk_candidate': 10},
-                        "aligned_simota": {'soft_center_radius': 3.0,
-                                           'topk_candicate': 10,
-                                           'iou_weight': 3.0},
-                                           },
+        'matcher_hpy': {'center_sampling_radius': 2.5,
+                        'topk_candidate': 10},
         # ---------------- Loss config ----------------
-        ## Loss weight
-        'ema_update': False,
+        'cls_loss': 'bce',
+        'loss_cls_weight': 1.0,
+        'loss_dfl_weight': 1.0,
+        'loss_box_weight': 5.0,
         'loss_box_aux': True,
-        'loss_weights': {"simota": {'loss_cls_weight': 1.0,
-                                    'loss_dfl_weight': 1.0,
-                                    'loss_box_weight': 5.0},
-                         "aligned_simota": {'loss_cls_weight': 1.0,
-                                            'loss_dfl_weight': 1.0,
-                                            'loss_box_weight': 2.0}
-                                            },
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
     },
@@ -371,23 +326,14 @@ rtcdet_cfg = {
         # ---------------- Assignment config ----------------
         ## Matcher
         'matcher': "simota",
-        'matcher_hpy': {"simota": {'center_sampling_radius': 2.5,
-                                   'topk_candidate': 10},
-                        "aligned_simota": {'soft_center_radius': 3.0,
-                                           'topk_candicate': 10,
-                                           'iou_weight': 3.0},
-                                           },
+        'matcher_hpy': {'center_sampling_radius': 2.5,
+                        'topk_candidate': 10},
         # ---------------- Loss config ----------------
-        ## Loss weight
-        'ema_update': False,
+        'cls_loss': 'bce',
+        'loss_cls_weight': 1.0,
+        'loss_dfl_weight': 1.0,
+        'loss_box_weight': 5.0,
         'loss_box_aux': True,
-        'loss_weights': {"simota": {'loss_cls_weight': 1.0,
-                                    'loss_dfl_weight': 1.0,
-                                    'loss_box_weight': 5.0},
-                         "aligned_simota": {'loss_cls_weight': 1.0,
-                                            'loss_dfl_weight': 1.0,
-                                            'loss_box_weight': 2.0}
-                                            },
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
     },
@@ -436,23 +382,14 @@ rtcdet_cfg = {
         # ---------------- Assignment config ----------------
         ## Matcher
         'matcher': "simota",
-        'matcher_hpy': {"simota": {'center_sampling_radius': 2.5,
-                                   'topk_candidate': 10},
-                        "aligned_simota": {'soft_center_radius': 3.0,
-                                           'topk_candicate': 10,
-                                           'iou_weight': 3.0},
-                                           },
+        'matcher_hpy': {'center_sampling_radius': 2.5,
+                        'topk_candidate': 10},
         # ---------------- Loss config ----------------
-        ## Loss weight
-        'ema_update': False,
+        'cls_loss': 'bce',
+        'loss_cls_weight': 1.0,
+        'loss_dfl_weight': 1.0,
+        'loss_box_weight': 5.0,
         'loss_box_aux': True,
-        'loss_weights': {"simota": {'loss_cls_weight': 1.0,
-                                    'loss_dfl_weight': 1.0,
-                                    'loss_box_weight': 5.0},
-                         "aligned_simota": {'loss_cls_weight': 1.0,
-                                            'loss_dfl_weight': 1.0,
-                                            'loss_box_weight': 2.0}
-                                            },
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
     },
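The same change repeats for every model size in `rtcdet_cfg`: the per-matcher nesting and the removed `'ema_update'` flag give way to flat loss keys plus a new `'cls_loss'` selector (`'bce'` or `'qfl'`). A minimal sketch of what the flattened schema buys downstream (hypothetical excerpt, mirroring the keys added above):

```python
# Hypothetical excerpt of the flattened config; keys mirror the diff above.
cfg = {
    'matcher': 'simota',
    'matcher_hpy': {'center_sampling_radius': 2.5, 'topk_candidate': 10},
    'cls_loss': 'bce',          # 'qfl' selects the quality-focal-loss path
    'loss_cls_weight': 1.0,
    'loss_dfl_weight': 1.0,
    'loss_box_weight': 5.0,
    'loss_box_aux': True,
}

# Before: cfg['matcher_hpy'][cfg['matcher']]['topk_candidate']
# After:  a single level of lookup.
topk = cfg['matcher_hpy']['topk_candidate']
print(topk)  # 10
```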

+ 84 - 91
models/detectors/rtcdet/loss.py

@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from utils.box_ops import bbox2dist, get_ious
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
-from .matcher import build_matcher
+from .matcher import SimOTA
 
 
 # ----------------------- Criterion for training -----------------------
@@ -16,25 +16,17 @@ class Criterion(object):
         self.num_classes = num_classes
         self.max_epoch = args.max_epoch
         self.no_aug_epoch = args.no_aug_epoch
-        self.use_ema_update = cfg['ema_update']
-        self.loss_box_aux    = cfg['loss_box_aux']
         # ---------------- Loss weight ----------------
-        loss_weights = cfg['loss_weights'][cfg['matcher']]
-        self.loss_cls_weight = loss_weights['loss_cls_weight']
-        self.loss_box_weight = loss_weights['loss_box_weight']
-        self.loss_dfl_weight = loss_weights['loss_dfl_weight']
+        self.loss_box_aux    = cfg['loss_box_aux']
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_box_weight = cfg['loss_box_weight']
+        self.loss_dfl_weight = cfg['loss_dfl_weight']
         # ---------------- Matcher ----------------
         ## Aligned SimOTA assigner
-        self.matcher = build_matcher(cfg, num_classes)
-
-    def ema_update(self, name: str, value, initial_value, momentum=0.9):
-        if hasattr(self, name):
-            old = getattr(self, name)
-        else:
-            old = initial_value
-        new = old * momentum + value * (1 - momentum)
-        setattr(self, name, new)
-        return new
+        self.matcher_hpy = cfg['matcher_hpy']
+        self.matcher = SimOTA(num_classes            = num_classes,
+                              center_sampling_radius = self.matcher_hpy['center_sampling_radius'],
+                              topk_candidate         = self.matcher_hpy['topk_candidate'])
 
     # ----------------- Loss functions -----------------
     def loss_classes(self, pred_cls, gt_score):
@@ -117,7 +109,7 @@ class Criterion(object):
         return loss_box_aux
     
     # ----------------- Main process -----------------
-    def loss_simota(self, outputs, targets, epoch=0):
+    def compute_loss1(self, outputs, targets, epoch=0):
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device
         fpn_strides = outputs['strides']
@@ -177,22 +169,16 @@ class Criterion(object):
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_fgs)
         num_fgs = (num_fgs / get_world_size()).clamp(1.0)
-
-        # update loss normalizer with EMA
-        if self.use_ema_update:
-            normalizer = self.ema_update("loss_normalizer", max(num_fgs, 1), 100)
-        else:
-            normalizer = num_fgs
         
         # ------------------ Classification loss ------------------
         cls_preds = cls_preds.view(-1, self.num_classes)
         loss_cls = self.loss_classes(cls_preds, cls_targets)
-        loss_cls = loss_cls.sum() / normalizer
+        loss_cls = loss_cls.sum() / num_fgs
 
         # ------------------ Regression loss ------------------
         box_preds_pos = box_preds.view(-1, 4)[fg_masks]
         loss_box = self.loss_bboxes(box_preds_pos, box_targets)
-        loss_box = loss_box.sum() / normalizer
+        loss_box = loss_box.sum() / num_fgs
 
         # ------------------ Distribution focal loss  ------------------
         ## process anchors
@@ -207,7 +193,7 @@ class Criterion(object):
         strides_pos = strides[fg_masks]
         ## compute dfl
         loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos)
-        loss_dfl = loss_dfl.sum() / normalizer
+        loss_dfl = loss_dfl.sum() / num_fgs
 
         # total loss
         losses = self.loss_cls_weight * loss_cls + \
@@ -228,7 +214,7 @@ class Criterion(object):
             delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
             ## aux loss
             loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets, anchors_pos, strides_pos)
-            loss_box_aux = loss_box_aux.sum() / normalizer
+            loss_box_aux = loss_box_aux.sum() / num_fgs
 
             losses += loss_box_aux
             loss_dict['loss_box_aux'] = loss_box_aux
@@ -236,19 +222,12 @@ class Criterion(object):
 
         return loss_dict
 
-    def loss_aligned_simota(self, outputs, targets, epoch=0):
-        """
-            outputs['pred_cls']: List(Tensor) [B, M, C]
-            outputs['pred_box']: List(Tensor) [B, M, 4]
-            outputs['strides']: List(Int) [8, 16, 32] output stride
-            targets: (List) [dict{'boxes': [...], 
-                                 'labels': [...], 
-                                 'orig_size': ...}, ...]
-        """
+    def compute_loss2(self, outputs, targets, epoch=0):
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device
         fpn_strides = outputs['strides']
         anchors = outputs['anchors']
+        num_anchors = sum([ab.shape[0] for ab in anchors])
         # preds: [B, M, C]
         cls_preds = torch.cat(outputs['pred_cls'], dim=1)
         reg_preds = torch.cat(outputs['pred_reg'], dim=1)
@@ -257,54 +236,65 @@ class Criterion(object):
         # --------------- label assignment ---------------
         cls_targets = []
         box_targets = []
-        assign_metrics = []
+        iou_targets = []
+        fg_masks = []
         for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
-            tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
-            # label assignment
-            assigned_result = self.matcher(fpn_strides=fpn_strides,
-                                           anchors=anchors,
-                                           pred_cls=cls_preds[batch_idx].detach(),
-                                           pred_box=box_preds[batch_idx].detach(),
-                                           gt_labels=tgt_labels,
-                                           gt_bboxes=tgt_bboxes
-                                           )
-            cls_targets.append(assigned_result['assigned_labels'])
-            box_targets.append(assigned_result['assigned_bboxes'])
-            assign_metrics.append(assigned_result['assign_metrics'])
-
-        cls_targets = torch.cat(cls_targets, dim=0)
-        box_targets = torch.cat(box_targets, dim=0)
-        assign_metrics = torch.cat(assign_metrics, dim=0)
-
-        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
-        bg_class_ind = self.num_classes
-        pos_inds = ((cls_targets >= 0)
-                    & (cls_targets < bg_class_ind)).nonzero().squeeze(1)
-        num_fgs = assign_metrics.sum()
+            tgt_labels = targets[batch_idx]["labels"].to(device)
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
 
+            # check target
+            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
+                # There is no valid gt
+                cls_target = cls_preds.new_full([num_anchors], self.num_classes, dtype=torch.long)
+                iou_target = cls_preds.new_zeros([num_anchors])
+                box_target = cls_preds.new_zeros((0, 4))
+                fg_mask = cls_preds.new_zeros(num_anchors).bool()
+            else:
+                (
+                    fg_mask,
+                    assigned_labels,
+                    assigned_ious,
+                    assigned_indexs
+                ) = self.matcher(
+                    fpn_strides = fpn_strides,
+                    anchors = anchors,
+                    pred_cls = cls_preds[batch_idx], 
+                    pred_box = box_preds[batch_idx],
+                    tgt_labels = tgt_labels,
+                    tgt_bboxes = tgt_bboxes
+                    )
+                # prepare cls targets
+                cls_target = assigned_labels.new_full([num_anchors], self.num_classes, dtype=torch.long)
+                cls_target[fg_mask] = assigned_labels
+                iou_target = assigned_ious.new_zeros([num_anchors])
+                iou_target[fg_mask] = assigned_ious
+                # prepare box targets
+                box_target = tgt_bboxes[assigned_indexs]
+
+            cls_targets.append(cls_target)
+            box_targets.append(box_target)
+            iou_targets.append(iou_target)
+            fg_masks.append(fg_mask)
+
+        cls_targets = torch.cat(cls_targets, 0)   # [M,]
+        box_targets = torch.cat(box_targets, 0)   # [M, 4]
+        iou_targets = torch.cat(iou_targets, 0)   # [M,]
+        fg_masks = torch.cat(fg_masks, 0)
+        num_fgs = fg_masks.sum()
+
+        # average loss normalizer across all the GPUs
         if is_dist_avail_and_initialized():
             torch.distributed.all_reduce(num_fgs)
-        num_fgs = (num_fgs / get_world_size()).clamp(1.0).item()
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
         
-        # update loss normalizer with EMA
-        if self.use_ema_update:
-            normalizer = self.ema_update("loss_normalizer", max(num_fgs, 1), 100)
-        else:
-            normalizer = num_fgs
-
-        # ---------------------------- Classification loss ----------------------------
+        # ------------------ Classification loss ------------------
         cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes_qfl(cls_preds, (cls_targets, assign_metrics))
-        loss_cls = loss_cls.sum() / normalizer
+        loss_cls = self.loss_classes_qfl(cls_preds, (cls_targets, iou_targets))
+        loss_cls = loss_cls.sum() / num_fgs
 
-        # ---------------------------- Regression loss ----------------------------
-        box_preds_pos = box_preds.view(-1, 4)[pos_inds]
-        box_targets_pos = box_targets[pos_inds]
-        box_weight_pos = assign_metrics[pos_inds]
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos)
-        loss_box *= box_weight_pos
-        loss_box = loss_box.sum() / normalizer
+        # ------------------ Regression loss ------------------
+        loss_box = self.loss_bboxes(box_preds.view(-1, 4)[fg_masks], box_targets)
+        loss_box = loss_box.sum() / num_fgs
 
         # ------------------ Distribution focal loss  ------------------
         ## process anchors
@@ -314,13 +304,12 @@ class Criterion(object):
         strides = torch.cat(outputs['stride_tensor'], dim=0)
         strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
         ## fg preds
-        reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[pos_inds]
-        anchors_pos = anchors[pos_inds]
-        strides_pos = strides[pos_inds]
+        reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
+        anchors_pos = anchors[fg_masks]
+        strides_pos = strides[fg_masks]
         ## compute dfl
-        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos)
-        loss_dfl *= box_weight_pos
-        loss_dfl = loss_dfl.sum() / normalizer
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos)
+        loss_dfl = loss_dfl.sum() / num_fgs
 
         # total loss
         losses = self.loss_cls_weight * loss_cls + \
@@ -338,22 +327,26 @@ class Criterion(object):
         if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.loss_box_aux:
             ## delta_preds
             delta_preds = torch.cat(outputs['pred_delta'], dim=1)
-            delta_preds_pos = delta_preds.view(-1, 4)[pos_inds]
+            delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
             ## aux loss
-            loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets_pos, anchors_pos, strides_pos)
-            loss_box_aux = loss_box_aux.sum() / normalizer
+            loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets, anchors_pos, strides_pos)
+            loss_box_aux = loss_box_aux.sum() / num_fgs
 
             losses += loss_box_aux
             loss_dict['loss_box_aux'] = loss_box_aux
 
+
         return loss_dict
 
 
     def __call__(self, outputs, targets, epoch=0):
-        if self.cfg['matcher'] == "simota":
-            return self.loss_simota(outputs, targets, epoch)
-        elif self.cfg['matcher'] == "aligned_simota":
-            return self.loss_aligned_simota(outputs, targets, epoch)
+        if self.cfg['cls_loss'] == "bce":
+            return self.compute_loss1(outputs, targets, epoch)
+        elif self.cfg['cls_loss'] == "qfl":
+            self.loss_box_weight = 2.0
+            return self.compute_loss2(outputs, targets, epoch)
+        else:
+            raise NotImplementedError
         
 
 def build_criterion(args, cfg, device, num_classes):
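`__call__` now dispatches on `cfg['cls_loss']` rather than `cfg['matcher']`, and the `'qfl'` branch overrides `loss_box_weight` to 2.0 before delegating to `compute_loss2`. For intuition about that branch, here is a self-contained sketch of quality focal loss, mirroring the `loss_classes` implementation added in `rtcdetv2/loss.py` below (toy inputs; `beta=2.0` as there):

```python
import torch
import torch.nn.functional as F

def quality_focal_loss(pred_cls, label, score, beta=2.0):
    # Every (anchor, class) cell starts as a negative, down-weighted by sigmoid(pred)^beta.
    pred_sigmoid = pred_cls.sigmoid()
    loss = F.binary_cross_entropy_with_logits(
        pred_cls, torch.zeros_like(pred_cls), reduction='none'
    ) * pred_sigmoid.pow(beta)
    # Positive cells regress the IoU-quality score, weighted by |score - p|^beta.
    num_classes = pred_cls.shape[-1]
    pos = ((label >= 0) & (label < num_classes)).nonzero().squeeze(1)
    pos_label = label[pos].long()
    scale = (score[pos] - pred_sigmoid[pos, pos_label]).abs().pow(beta)
    loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
        pred_cls[pos, pos_label], score[pos], reduction='none'
    ) * scale
    return loss

pred  = torch.randn(4, 3)                    # 4 anchors, 3 classes
label = torch.tensor([0, 2, 3, 3])           # 3 == background index
score = torch.tensor([0.9, 0.6, 0.0, 0.0])   # assigned IoUs for the positives
print(quality_focal_loss(pred, label, score).sum())
```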

+ 0 - 188
models/detectors/rtcdet/matcher.py

@@ -178,191 +178,3 @@ class SimOTA(object):
             fg_mask_inboxes
         ]
         return assigned_labels, assigned_ious, assigned_indexs
-    
-
-# -------------------------- RTMDet's Aligned SimOTA Assigner --------------------------
-## Aligned SimOTA
-class AlignedSimOTA(object):
-    """
-        This code referenced to https://github.com/open-mmlab/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py
-    """
-    def __init__(self, num_classes, soft_center_radius=3.0, topk=13, iou_weight=3.0):
-        self.num_classes = num_classes
-        self.soft_center_radius = soft_center_radius
-        self.topk = topk
-        self.iou_weight = iou_weight
-
-
-    @torch.no_grad()
-    def __call__(self, 
-                 fpn_strides, 
-                 anchors, 
-                 pred_cls, 
-                 pred_box, 
-                 gt_labels,
-                 gt_bboxes):
-        # [M,]
-        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
-                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
-        # List[F, M, 2] -> [M, 2]
-        anchors = torch.cat(anchors, dim=0)
-        num_gt = len(gt_labels)
-
-        # check gt
-        if num_gt == 0 or gt_bboxes.max().item() == 0.:
-            return {
-                'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape,
-                                                      self.num_classes,
-                                                      dtype=torch.long),
-                'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
-                'assign_metrics': gt_bboxes.new_full(pred_cls[..., 0].shape, 0)
-            }
-        
-        # get inside points: [N, M]
-        is_in_gt = self.find_inside_points(gt_bboxes, anchors)
-        valid_mask = is_in_gt.sum(dim=0) > 0  # [M,]
-
-        # ----------------------------------- Soft center prior -----------------------------------
-        gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
-        distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
-                    ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0)  # [N, M]
-        distance = distance * valid_mask.unsqueeze(0)
-        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
-
-        # ----------------------------------- Regression cost -----------------------------------
-        pair_wise_ious, _ = box_iou(gt_bboxes, pred_box)  # [N, M]
-        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * self.iou_weight
-
-        # ----------------------------------- Classification cost -----------------------------------
-        ## select the predicted scores corresponded to the gt_labels
-        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
-        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()   # [N, M]
-        ## scale factor
-        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
-        ## cls cost
-        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
-            pairwise_pred_scores, pair_wise_ious,
-            reduction="none") * scale_factor # [N, M]
-            
-        del pairwise_pred_scores
-
-        ## foreground cost matrix
-        cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
-        max_pad_value = torch.ones_like(cost_matrix) * 1e9
-        cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1),   # [N, M]
-                                  cost_matrix, max_pad_value)
-
-        # ----------------------------------- dynamic label assignment -----------------------------------
-        (
-            matched_pred_ious,
-            matched_gt_inds,
-            fg_mask_inboxes
-        ) = self.dynamic_k_matching(
-            cost_matrix,
-            pair_wise_ious,
-            num_gt
-            )
-        del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss
-
-        # -----------------------------------process assigned labels -----------------------------------
-        assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
-                                             self.num_classes)  # [M,]
-        assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
-        assigned_labels = assigned_labels.long()  # [M,]
-
-        assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0)        # [M, 4]
-        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds]  # [M, 4]
-
-        assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0) # [M, 4]
-        assign_metrics[fg_mask_inboxes] = matched_pred_ious            # [M, 4]
-
-        assigned_dict = dict(
-            assigned_labels=assigned_labels,
-            assigned_bboxes=assigned_bboxes,
-            assign_metrics=assign_metrics
-            )
-        
-        return assigned_dict
-
-
-    def find_inside_points(self, gt_bboxes, anchors):
-        """
-            gt_bboxes: Tensor -> [N, 2]
-            anchors:   Tensor -> [M, 2]
-        """
-        num_anchors = anchors.shape[0]
-        num_gt = gt_bboxes.shape[0]
-
-        anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1)           # [N, M, 2]
-        gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1)  # [N, M, 4]
-
-        # offset
-        lt = anchors_expand - gt_bboxes_expand[..., :2]
-        rb = gt_bboxes_expand[..., 2:] - anchors_expand
-        bbox_deltas = torch.cat([lt, rb], dim=-1)
-
-        is_in_gts = bbox_deltas.min(dim=-1).values > 0
-
-        return is_in_gts
-    
-
-    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
-        """Use IoU and matching cost to calculate the dynamic top-k positive
-        targets.
-
-        Args:
-            cost_matrix (Tensor): Cost matrix.
-            pairwise_ious (Tensor): Pairwise iou matrix.
-            num_gt (int): Number of gt.
-            valid_mask (Tensor): Mask for valid bboxes.
-        Returns:
-            tuple: matched ious and gt indexes.
-        """
-        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
-        # select candidate topk ious for dynamic-k calculation
-        candidate_topk = min(self.topk, pairwise_ious.size(1))
-        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
-        # calculate dynamic k for each gt
-        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-
-        # sorting the batch cost matirx is faster than topk
-        _, sorted_indices = torch.sort(cost_matrix, dim=1)
-        for gt_idx in range(num_gt):
-            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
-            matching_matrix[gt_idx, :][topk_ids] = 1
-
-        del topk_ious, dynamic_ks, topk_ids
-
-        prior_match_gt_mask = matching_matrix.sum(0) > 1
-        if prior_match_gt_mask.sum() > 0:
-            cost_min, cost_argmin = torch.min(
-                cost_matrix[:, prior_match_gt_mask], dim=0)
-            matching_matrix[:, prior_match_gt_mask] *= 0
-            matching_matrix[cost_argmin, prior_match_gt_mask] = 1
-
-        # get foreground mask inside box and center prior
-        fg_mask_inboxes = matching_matrix.sum(0) > 0
-        matched_pred_ious = (matching_matrix *
-                             pairwise_ious).sum(0)[fg_mask_inboxes]
-        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
-
-        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes
-
-
-def build_matcher(cfg, num_classes):
-    if cfg['matcher'] == "simota":
-        matcher = SimOTA(
-            center_sampling_radius=cfg['matcher_hpy'][cfg['matcher']]['center_sampling_radius'],
-            topk_candidate=cfg['matcher_hpy'][cfg['matcher']]['topk_candidate'],
-            num_classes=num_classes
-        )
-
-    elif cfg['matcher'] == "aligned_simota":
-        matcher = AlignedSimOTA(
-            num_classes=num_classes,
-            soft_center_radius=cfg['matcher_hpy'][cfg['matcher']]['soft_center_radius'],
-            topk=cfg['matcher_hpy'][cfg['matcher']]['topk_candicate'],
-            iou_weight=cfg['matcher_hpy'][cfg['matcher']]['iou_weight']
-            )
-
-    return matcher
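With `AlignedSimOTA` and the `build_matcher` factory deleted from this file, each criterion constructs its assigner directly, as the `loss.py` hunk above now does. A minimal sketch, assuming the repo root is on the import path:

```python
from models.detectors.rtcdet.matcher import SimOTA

# Direct construction replaces the deleted build_matcher() factory;
# the hyperparameters come straight from cfg['matcher_hpy'].
matcher = SimOTA(num_classes            = 80,
                 center_sampling_radius = 2.5,
                 topk_candidate         = 10)
```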

+ 236 - 0
models/detectors/rtcdetv2/loss.py

@@ -0,0 +1,236 @@
+import torch
+import torch.nn.functional as F
+
+from utils.box_ops import bbox2dist, get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import AlignedSimOTA
+
+
+# ----------------------- Criterion for training -----------------------
+class Criterion(object):
+    def __init__(self, args, cfg, device, num_classes=80):
+        self.cfg = cfg
+        self.args = args
+        self.device = device
+        self.num_classes = num_classes
+        self.max_epoch = args.max_epoch
+        self.no_aug_epoch = args.no_aug_epoch
+        # ---------------- Loss weight ----------------
+        self.loss_box_aux    = cfg['loss_box_aux']
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_box_weight = cfg['loss_box_weight']
+        self.loss_dfl_weight = cfg['loss_dfl_weight']
+        # ---------------- Matcher ----------------
+        self.matcher_hpy = cfg["matcher_hpy"]
+        self.matcher = AlignedSimOTA(
+            num_classes            = num_classes,
+            center_sampling_radius = self.matcher_hpy['center_sampling_radius'],
+            topk_candidates        = self.matcher_hpy['topk_candidates']
+            )
+
+    # ----------------- Loss functions -----------------
+    def loss_classes(self, pred_cls, target, beta=2.0):
+        # Quality FocalLoss
+        """
+            pred_cls: (torch.Tensor): [N, C]。
+            target:   (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N,)
+        """
+        label, score = target
+        pred_sigmoid = pred_cls.sigmoid()
+        scale_factor = pred_sigmoid
+        zerolabel = scale_factor.new_zeros(pred_cls.shape)
+
+        ce_loss = F.binary_cross_entropy_with_logits(
+            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
+        
+        bg_class_ind = pred_cls.shape[-1]
+        pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
+        pos_label = label[pos].long()
+
+        scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
+
+        ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
+            pred_cls[pos, pos_label], score[pos],
+            reduction='none') * scale_factor.abs().pow(beta)
+
+        return ce_loss
+    
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box, gt_box, 'xyxy', 'giou')
+        loss_box = 1.0 - ious
+
+        return loss_box
+    
+    def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
+        # rescale coords by stride
+        gt_box_s = gt_box / stride
+        anchor_s = anchor / stride
+
+        # compute deltas
+        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.cfg['reg_max'] - 1)
+
+        gt_left = gt_ltrb_s.to(torch.long)
+        gt_right = gt_left + 1
+
+        weight_left = gt_right.to(torch.float) - gt_ltrb_s
+        weight_right = 1 - weight_left
+
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_reg.view(-1, self.cfg['reg_max']),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_reg.view(-1, self.cfg['reg_max']),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss_dfl = (loss_left + loss_right).mean(-1)
+        
+        if bbox_weight is not None:
+            loss_dfl *= bbox_weight
+
+        return loss_dfl
+
+    def loss_bboxes_aux(self, pred_delta, gt_box, anchors, stride_tensors):
+        gt_delta_tl = (anchors - gt_box[..., :2]) / stride_tensors
+        gt_delta_rb = (gt_box[..., 2:] - anchors) / stride_tensors
+        gt_delta = torch.cat([gt_delta_tl, gt_delta_rb], dim=1)
+        loss_box_aux = F.l1_loss(pred_delta, gt_delta, reduction='none')
+
+        return loss_box_aux
+    
+    # ----------------- Main process -----------------
+    def __call__(self, outputs, targets, epoch=0):
+        bs = outputs['pred_cls'][0].shape[0]
+        device = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        anchors = outputs['anchors']
+        num_anchors = sum([ab.shape[0] for ab in anchors])
+        # preds: [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+
+        # --------------- label assignment ---------------
+        cls_targets = []
+        box_targets = []
+        iou_targets = []
+        fg_masks = []
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
+                # There is no valid gt
+                cls_target = cls_preds.new_full([num_anchors], self.num_classes, dtype=torch.long)
+                iou_target = cls_preds.new_zeros((num_anchors,))
+                box_target = cls_preds.new_zeros((0, 4))
+                fg_mask = cls_preds.new_zeros(num_anchors).bool()
+            else:
+                (
+                    fg_mask,
+                    assigned_labels,
+                    assigned_ious,
+                    assigned_indexs
+                ) = self.matcher(
+                    fpn_strides = fpn_strides,
+                    anchors = anchors,
+                    pred_cls = cls_preds[batch_idx], 
+                    pred_box = box_preds[batch_idx],
+                    tgt_labels = tgt_labels,
+                    tgt_bboxes = tgt_bboxes
+                    )
+                # prepare cls targets
+                cls_target = assigned_labels.new_full([num_anchors], self.num_classes, dtype=torch.long)
+                cls_target[fg_mask] = assigned_labels
+                iou_target = assigned_ious.new_zeros([num_anchors])
+                iou_target[fg_mask] = assigned_ious
+                # prepare box targets
+                box_target = tgt_bboxes[assigned_indexs]
+
+            cls_targets.append(cls_target)
+            box_targets.append(box_target)
+            iou_targets.append(iou_target)
+            fg_masks.append(fg_mask)
+
+        cls_targets = torch.cat(cls_targets, 0)   # [M,]
+        box_targets = torch.cat(box_targets, 0)   # [M, 4]
+        iou_targets = torch.cat(iou_targets, 0)   # [M,]
+        fg_masks = torch.cat(fg_masks, 0)
+        num_fgs = fg_masks.sum()
+
+        # average loss normalizer across all the GPUs
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+        
+        # ------------------ Classification loss ------------------
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        loss_cls = self.loss_classes(cls_preds, (cls_targets, iou_targets))
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # ------------------ Regression loss ------------------
+        loss_box = self.loss_bboxes(box_preds.view(-1, 4)[fg_masks], box_targets)
+        loss_box = loss_box.sum() / num_fgs
+
+        # ------------------ Distribution focal loss  ------------------
+        ## process anchors
+        anchors = torch.cat(anchors, dim=0)
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
+        ## process stride tensors
+        strides = torch.cat(outputs['stride_tensor'], dim=0)
+        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
+        ## fg preds
+        reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
+        anchors_pos = anchors[fg_masks]
+        strides_pos = strides[fg_masks]
+        ## compute dfl
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos)
+        loss_dfl = loss_dfl.sum() / num_fgs
+
+        # total loss
+        losses = self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box + \
+                 self.loss_dfl_weight * loss_dfl
+
+        loss_dict = dict(
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                loss_dfl = loss_dfl,
+                losses = losses
+        )
+
+        # ------------------ Aux regression loss ------------------
+        if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.loss_box_aux:
+            ## delta_preds
+            delta_preds = torch.cat(outputs['pred_delta'], dim=1)
+            delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
+            ## aux loss
+            loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets, anchors_pos, strides_pos)
+            loss_box_aux = loss_box_aux.sum() / num_fgs
+
+            losses += loss_box_aux
+            loss_dict['loss_box_aux'] = loss_box_aux
+
+
+        return loss_dict
+
+
+def build_criterion(args, cfg, device, num_classes):
+    criterion = Criterion(
+        args=args,
+        cfg=cfg,
+        device=device,
+        num_classes=num_classes
+        )
+
+    return criterion
+
+
+if __name__ == "__main__":
+    pass
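For orientation, the `outputs` dict that the new `Criterion.__call__` consumes can be mocked like this (a toy sketch with made-up anchor counts; actually running the criterion additionally needs `args`, `cfg`, and `targets`):

```python
import torch

B, C, reg_max = 2, 80, 16
fpn_strides   = [8, 16, 32]
anchor_counts = [6400, 1600, 400]   # e.g. a 640x640 input

outputs = {
    'strides':       fpn_strides,
    'anchors':       [torch.rand(m, 2) * 640 for m in anchor_counts],
    'stride_tensor': [torch.full((m, 1), float(s))
                      for m, s in zip(anchor_counts, fpn_strides)],
    'pred_cls':      [torch.randn(B, m, C) for m in anchor_counts],
    'pred_reg':      [torch.randn(B, m, 4 * reg_max) for m in anchor_counts],
    'pred_box':      [torch.rand(B, m, 4) * 640 for m in anchor_counts],
    'pred_delta':    [torch.randn(B, m, 4) for m in anchor_counts],
}
print(sum(a.shape[0] for a in outputs['anchors']))  # num_anchors, as computed above
```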

+ 181 - 0
models/detectors/rtcdetv2/matcher.py

@@ -0,0 +1,181 @@
+# ----------------------------------------------------------------------------------------------------------------------
+# This code is adapted from https://github.com/open-mmlab/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py
+# ----------------------------------------------------------------------------------------------------------------------
+import torch
+import torch.nn.functional as F
+from utils.box_ops import *
+
+
+# -------------------------- RTMDet's Aligned SimOTA Assigner --------------------------
+## Aligned SimOTA
+class AlignedSimOTA(object):
+    def __init__(self, num_classes, center_sampling_radius=2.5, topk_candidates=10):
+        self.num_classes = num_classes
+        self.center_sampling_radius = center_sampling_radius
+        self.topk_candidates = topk_candidates
+
+
+    @torch.no_grad()
+    def __call__(self, 
+                 fpn_strides, 
+                 anchors, 
+                 pred_cls, 
+                 pred_box, 
+                 tgt_labels,
+                 tgt_bboxes):
+        # [M,]
+        strides_tensor = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
+        # List[F, M, 2] -> [M, 2]
+        anchors = torch.cat(anchors, dim=0)
+        num_anchor = anchors.shape[0]        
+        num_gt = len(tgt_labels)
+
+        # ----------------------- Find inside points -----------------------
+        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+            tgt_bboxes, anchors, strides_tensor, num_anchor, num_gt)
+        cls_preds = pred_cls[fg_mask].float()   # [Mp, C]
+        box_preds = pred_box[fg_mask].float()   # [Mp, 4]
+
+        # ----------------------- Reg cost -----------------------
+        pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds)      # [N, Mp]
+        reg_cost = -torch.log(pair_wise_ious + 1e-8)            # [N, Mp]
+
+        # ----------------------- Cls cost -----------------------
+        with torch.cuda.amp.autocast(enabled=False):
+            # [Mp, C] -> [N, Mp, C]
+            cls_preds_expand = cls_preds.unsqueeze(0).repeat(num_gt, 1, 1)
+            # prepare cls_target
+            cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
+            cls_targets = cls_targets.unsqueeze(1).repeat(1, cls_preds_expand.size(1), 1)
+            cls_targets *= pair_wise_ious.unsqueeze(-1)  # iou-aware
+            # [N, Mp]
+            cls_cost = F.binary_cross_entropy_with_logits(cls_preds_expand, cls_targets, reduction="none").sum(-1)
+        del cls_preds_expand
+
+        #----------------------- Dynamic K-Matching -----------------------
+        cost_matrix = (
+            cls_cost
+            + 3.0 * reg_cost
+            + 100000.0 * (~is_in_boxes_and_center)
+        ) # [N, Mp]
+
+        (
+            assigned_labels,         # [num_fg,]
+            assigned_ious,           # [num_fg,]
+            assigned_indexs,         # [num_fg,]
+        ) = self.dynamic_k_matching(
+            cost_matrix,
+            pair_wise_ious,
+            tgt_labels,
+            num_gt,
+            fg_mask
+            )
+        del cls_cost, cost_matrix, pair_wise_ious, reg_cost
+
+        return fg_mask, assigned_labels, assigned_ious, assigned_indexs
+
+
+    def get_in_boxes_info(
+        self,
+        gt_bboxes,   # [N, 4]
+        anchors,     # [M, 2]
+        strides,     # [M,]
+        num_anchors, # M
+        num_gt,      # N
+        ):
+        # anchor center
+        x_centers = anchors[:, 0]
+        y_centers = anchors[:, 1]
+
+        # [M,] -> [1, M] -> [N, M]
+        x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
+        y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
+
+        # [N,] -> [N, 1] -> [N, M]
+        gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
+        gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
+        gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
+        gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
+
+        b_l = x_centers - gt_bboxes_l
+        b_r = gt_bboxes_r - x_centers
+        b_t = y_centers - gt_bboxes_t
+        b_b = gt_bboxes_b - y_centers
+        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
+
+        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
+        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
+        # in fixed center
+        center_radius = self.center_sampling_radius
+
+        # [N, 2]
+        gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
+        
+        # [1, M]
+        center_radius_ = center_radius * strides.unsqueeze(0)
+
+        gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
+        gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
+        gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
+        gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
+
+        c_l = x_centers - gt_bboxes_l
+        c_r = gt_bboxes_r - x_centers
+        c_t = y_centers - gt_bboxes_t
+        c_b = gt_bboxes_b - y_centers
+        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
+        is_in_centers = center_deltas.min(dim=-1).values > 0.0
+        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+
+        # in boxes and in centers
+        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
+
+        is_in_boxes_and_center = (
+            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
+        )
+        return is_in_boxes_anchor, is_in_boxes_and_center
+    
+    
+    def dynamic_k_matching(
+        self, 
+        cost, 
+        pair_wise_ious, 
+        gt_classes, 
+        num_gt, 
+        fg_mask
+        ):
+        # Dynamic K
+        # ---------------------------------------------------------------
+        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+
+        ious_in_boxes_matrix = pair_wise_ious
+        n_candidate_k = min(self.topk_candidates, ious_in_boxes_matrix.size(1))
+        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+        dynamic_ks = dynamic_ks.tolist()
+        for gt_idx in range(num_gt):
+            _, pos_idx = torch.topk(
+                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
+            )
+            matching_matrix[gt_idx][pos_idx] = 1
+
+        del topk_ious, dynamic_ks, pos_idx
+
+        anchor_matching_gt = matching_matrix.sum(0)
+        if (anchor_matching_gt > 1).sum() > 0:
+            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
+            matching_matrix[:, anchor_matching_gt > 1] *= 0
+            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+
+        fg_mask[fg_mask.clone()] = fg_mask_inboxes
+
+        assigned_indexs = matching_matrix[:, fg_mask_inboxes].argmax(0)
+        assigned_labels = gt_classes[assigned_indexs]
+
+        assigned_ious = (matching_matrix * pair_wise_ious).sum(0)[
+            fg_mask_inboxes
+        ]
+        return assigned_labels, assigned_ious, assigned_indexs
+    
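A toy smoke test for the new assigner (run it inside the repo so `utils.box_ops` resolves; the anchor grid, predictions, and ground truth below are made-up values):

```python
import torch
from models.detectors.rtcdetv2.matcher import AlignedSimOTA

# An 8x8 grid of stride-8 anchor centers over a 64x64 image.
ys, xs = torch.meshgrid(torch.arange(8), torch.arange(8), indexing='ij')
anchors = torch.stack([xs.flatten(), ys.flatten()], dim=-1).float() * 8 + 4

matcher = AlignedSimOTA(num_classes=3, center_sampling_radius=2.5, topk_candidates=10)
fg_mask, labels, ious, idxs = matcher(
    fpn_strides = [8],
    anchors     = [anchors],
    pred_cls    = torch.randn(64, 3),
    pred_box    = torch.rand(64, 4).sort(dim=-1).values * 64,  # crude valid xyxy boxes
    tgt_labels  = torch.tensor([1]),
    tgt_bboxes  = torch.tensor([[8., 8., 40., 40.]]),
)
print(fg_mask.sum().item(), labels, ious, idxs)
```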

+ 164 - 0
models/detectors/rtcdetv2/rtcdetv2_backbone.py

@@ -0,0 +1,164 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .rtcdetv2_basic import Conv, ResXStage
+except ImportError:
+    from rtcdetv2_basic import Conv, ResXStage
+    
+model_urls = {
+    'resxnet_pico':   None,
+    'resxnet_nano':   None,
+    'resxnet_tiny':   None,
+    'resxnet_small':  None,
+    'resxnet_medium': None,
+    'resxnet_large':  None,
+    'resxnet_huge':   None,
+}
+
+# --------------------- ResXNet -----------------------
+class ResXNet(nn.Module):
+    def __init__(self,
+                 embed_dim    = 96,
+                 expand_ratio = 0.25,
+                 ffn_ratio    = 4.0,
+                 num_branches = 4,
+                 num_stages   = [3, 3, 9, 3],
+                 act_type     = 'silu',
+                 norm_type    = 'BN',
+                 depthwise    = False):
+        super(ResXNet, self).__init__()
+        # ------------------ Basic parameters ------------------
+        self.embed_dim = embed_dim
+        self.expand_ratio = expand_ratio
+        self.ffn_ratio = ffn_ratio
+        self.num_branches = num_branches
+        self.num_stages = num_stages
+        self.feat_dims = [embed_dim * 2, embed_dim * 4, embed_dim * 8]
+        
+        # ------------------ Network parameters ------------------
+        ## P2/4
+        self.layer_1 = nn.Sequential(
+            Conv(3, embed_dim, k=7, p=3, s=2, act_type=act_type, norm_type=norm_type),
+            nn.MaxPool2d((3, 3), stride=2, padding=1)
+        )
+        self.layer_2 = ResXStage(embed_dim, embed_dim, self.expand_ratio, self.ffn_ratio, self.num_branches, self.num_stages[0], True, act_type, norm_type, depthwise)
+        ## P3/8
+        self.layer_3 = nn.Sequential(
+            Conv(embed_dim, embed_dim*2, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
+            ResXStage(embed_dim*2, embed_dim*2, self.expand_ratio, self.ffn_ratio, self.num_branches, self.num_stages[1], True, act_type, norm_type, depthwise)
+        )
+        ## P4/16
+        self.layer_4 = nn.Sequential(
+            Conv(embed_dim*2, embed_dim*4, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
+            ResXStage(embed_dim*4, embed_dim*4, self.expand_ratio, self.ffn_ratio, self.num_branches, self.num_stages[2], True, act_type, norm_type, depthwise)
+        )
+        ## P5/32
+        self.layer_5 = nn.Sequential(
+            Conv(embed_dim*4, embed_dim*8, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
+            ResXStage(embed_dim*8, embed_dim*8, self.expand_ratio, self.ffn_ratio, self.num_branches, self.num_stages[3], True, act_type, norm_type, depthwise)
+        )
+
+    def forward(self, x):
+        c2 = self.layer_1(x)
+        c2 = self.layer_2(c2)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+## load pretrained weight
+def load_weight(model, model_name):
+    # load weight
+    print('Loading pretrained weight ...')
+    url = model_urls[model_name]
+    if url is not None:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url=url, map_location="cpu", check_hash=True)
+        # checkpoint state dict
+        checkpoint_state_dict = checkpoint.pop("model")
+        # model state dict
+        model_state_dict = model.state_dict()
+        # check
+        for k in list(checkpoint_state_dict.keys()):
+            if k in model_state_dict:
+                shape_model = tuple(model_state_dict[k].shape)
+                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                if shape_model != shape_checkpoint:
+                    checkpoint_state_dict.pop(k)
+            else:
+                checkpoint_state_dict.pop(k)
+                print(k)
+
+        model.load_state_dict(checkpoint_state_dict)
+    else:
+        print('No pretrained for {}'.format(model_name))
+
+    return model
+
+## build ELAN-Net
+def build_backbone(cfg, pretrained=False): 
+    # model
+    backbone = ResXNet(
+        embed_dim=cfg['embed_dim'],
+        expand_ratio=cfg['expand_ratio'],
+        ffn_ratio=cfg['ffn_ratio'],
+        num_branches=cfg['num_branches'],
+        num_stages=cfg['num_stages'],
+        act_type=cfg['bk_act'],
+        norm_type=cfg['bk_norm'],
+        depthwise=cfg['bk_depthwise']
+        )
+    # check whether to load imagenet pretrained weight
+    if pretrained:
+        if cfg['width'] == 0.25 and cfg['depth'] == 0.34 and cfg['bk_depthwise']:
+            backbone = load_weight(backbone, model_name='resxnet_pico')
+        elif cfg['width'] == 0.25 and cfg['depth'] == 0.34:
+            backbone = load_weight(backbone, model_name='resxnet_nano')
+        elif cfg['width'] == 0.375 and cfg['depth'] == 0.34:
+            backbone = load_weight(backbone, model_name='resxnet_tiny')
+        elif cfg['width'] == 0.5 and cfg['depth'] == 0.34:
+            backbone = load_weight(backbone, model_name='resxnet_small')
+        elif cfg['width'] == 0.75 and cfg['depth'] == 0.67:
+            backbone = load_weight(backbone, model_name='resxnet_medium')
+        elif cfg['width'] == 1.0 and cfg['depth'] == 1.0:
+            backbone = load_weight(backbone, model_name='resxnet_large')
+        elif cfg['width'] == 1.25 and cfg['depth'] == 1.34:
+            backbone = load_weight(backbone, model_name='resxnet_huge')
+
+    return backbone, backbone.feat_dims
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_depthwise': False,
+        'embed_dim': 96,
+        'expand_ratio': 0.25,
+        'ffn_ratio': 4.0,
+        'num_branches': 4,
+        'num_stages'  : [3, 3, 9, 3],
+    }
+    model, feats = build_backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 185 - 0
models/detectors/rtcdetv2/rtcdetv2_basic.py

@@ -0,0 +1,185 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+# ---------------------------- 2D CNN ----------------------------
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+        
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type='lrelu',     # activation
+                 norm_type='BN',       # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        p = p if d == 1 else d
+
+        if depthwise:
+            # Depthwise Conv
+            assert c1 == c2
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            # norm & activation
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+        else:
+            # Naive Conv
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
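+# A minimal usage sketch for Conv (illustrative dims):
+#   conv = Conv(3, 64, k=3, p=1, s=2, act_type='silu', norm_type='BN')
+#   conv(torch.randn(1, 3, 640, 640))  # -> [1, 64, 320, 320]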
+
+# ----------------------------  Modules ----------------------------
+## Mixed ConvModule
+class MixedConvModule(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 expand_ratio :float = 0.25,
+                 num_branches :int   = 4,
+                 shortcut     :bool  = True,
+                 act_type     :str   = 'relu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False):
+        super(MixedConvModule, self).__init__()
+        # ----------- Basic Parameters -----------
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.expand_ratio = expand_ratio
+        self.num_branches = num_branches
+        self.shortcut = shortcut
+        self.inter_dim = round(in_dim * expand_ratio)
+        # ----------- Network Parameters -----------
+        self.input_proj = Conv(in_dim, self.inter_dim, k=1, act_type=None, norm_type=norm_type)
+        self.branches = nn.ModuleList([
+            Conv(self.inter_dim, self.inter_dim, k=3, p=1, s=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(num_branches)])
+        self.output_proj = Conv(self.inter_dim * self.num_branches, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+    def forward(self, x):
+        y = self.input_proj(x)
+        outs = []
+        for layer in self.branches:
+            y = layer(y)
+            outs.append(y)
+        outs = torch.cat(outs, dim=1)
+
+        return x + self.output_proj(outs) if self.shortcut else self.output_proj(outs)
+
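+# Channel-flow sketch for MixedConvModule (illustrative dims, defaults assumed):
+#   in_dim=64, expand_ratio=0.25 -> inter_dim=16; each branch refines the previous
+#   branch's output, the 4 intermediate results are concatenated (4*16=64 channels),
+#   and a 1x1 conv projects to out_dim. The residual add assumes in_dim == out_dim.
+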
+## Conv-style FFN
+class ConvFFN(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 expand_ratio :float = 2.0,
+                 shortcut     :bool  = True,
+                 act_type     :str   = 'silu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False):
+        super(ConvFFN, self).__init__()
+        # ----------- Basic Parameters -----------
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.shortcut = shortcut
+        self.expand_dim = round(in_dim * expand_ratio)
+        # ----------- Network Parameters -----------
+        self.conv_ffn = nn.Sequential(
+            Conv(in_dim, self.expand_dim, k=1, act_type=act_type, norm_type=norm_type),
+            Conv(self.expand_dim, in_dim, k=1, act_type=None, norm_type=norm_type)
+        )
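+        # NOTE: the second 1x1 conv projects back to in_dim (not out_dim), so the
+        # residual add in forward() is always shape-valid; out_dim and depthwise
+        # are kept for interface consistency but are unused here.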
+
+    def forward(self, x):
+        return x + self.conv_ffn(x) if self.shortcut else self.conv_ffn(x)
+
+## ResXBlock
+class ResXBlock(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 expand_ratio :float = 0.25,
+                 ffn_ratio    :float = 2.0,
+                 num_branches :int   = 4,
+                 shortcut     :bool  = True,
+                 act_type     :str   ='silu',
+                 norm_type    :str   ='BN',
+                 depthwise    :bool  = False):
+        super(ResXBlock, self).__init__()
+        self.layer1 = MixedConvModule(in_dim, out_dim, expand_ratio, num_branches, shortcut, act_type, norm_type, depthwise)
+        self.layer2 = ConvFFN(out_dim, out_dim, ffn_ratio, shortcut, act_type, norm_type, depthwise)
+
+    def forward(self, x):
+        x = self.layer1(x)
+        x = self.layer2(x)
+        return x
+
+## ResXStage
+class ResXStage(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 expand_ratio :float = 0.25,
+                 ffn_ratio    :float = 2.0,
+                 num_branches :int   = 4,
+                 num_blocks   :int   = 1,
+                 shortcut     :bool  = True,
+                 act_type     :str   ='silu',
+                 norm_type    :str   ='BN',
+                 depthwise    :bool  = False):
+        super(ResXStage, self).__init__()
+        stages = []
+        for i in range(num_blocks):
+            if i == 0:
+                stages.append(ResXBlock(in_dim, out_dim, expand_ratio, ffn_ratio, num_branches, shortcut, act_type, norm_type, depthwise))
+            else:
+                stages.append(ResXBlock(out_dim, out_dim, expand_ratio, ffn_ratio, num_branches, shortcut, act_type, norm_type, depthwise))
+        self.stages = nn.Sequential(*stages)
+
+    def forward(self, x):
+        return self.stages(x)
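+
+
+if __name__ == '__main__':
+    # Quick sanity check (a sketch; the dims below are illustrative, not from any config)
+    block = ResXBlock(in_dim=64, out_dim=64, expand_ratio=0.25, ffn_ratio=2.0, num_branches=4)
+    x = torch.randn(1, 64, 32, 32)
+    print(block(x).shape)  # expected: torch.Size([1, 64, 32, 32])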

+ 181 - 0
models/detectors/rtcdetv2/rtcdetv2_pafpn.py

@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    from .rtcdetv2_basic import Conv, ResXStage
+except ImportError:
+    from rtcdetv2_basic import Conv, ResXStage
+
+
+# PaFPN-CSP
+class RTCDetv2PaFPN(nn.Module):
+    def __init__(self, 
+                 in_dims=[256, 512, 1024],
+                 out_dim=256,
+                 width=1.0,
+                 depth=1.0,
+                 act_type='silu',
+                 norm_type='BN',
+                 depthwise=False):
+        super(RTCDetv2PaFPN, self).__init__()
+        # ------------- Basic parameters -------------
+        self.in_dims = in_dims
+        self.out_dim = out_dim
+        self.expand_ratios = [0.25, 0.25, 0.25, 0.25]
+        self.ffn_ratios = [4.0, 4.0, 4.0, 4.0]
+        self.num_branches = [4, 4, 4, 4]
+        self.num_blocks = [round(2 * depth), round(2 * depth), round(2 * depth), round(2 * depth)]
+        c3, c4, c5 = in_dims
+
+        # top down
+        ## P5 -> P4
+        self.reduce_layer_1 = Conv(c5, round(384*width), k=1, act_type=act_type, norm_type=norm_type)
+        self.top_down_layer_1 = ResXStage(in_dim       = c4 + round(384*width),
+                                          out_dim      = round(384*width),
+                                          expand_ratio = self.expand_ratios[0],
+                                          ffn_ratio    = self.ffn_ratios[0],
+                                          num_branches = self.num_branches[0],
+                                          num_blocks   = self.num_blocks[0],
+                                          shortcut     = False,
+                                          act_type     = act_type,
+                                          norm_type    = norm_type,
+                                          depthwise    = depthwise
+                                          )
+
+        ## P4 -> P3
+        self.reduce_layer_2 = Conv(c4, round(192*width), k=1, norm_type=norm_type, act_type=act_type)
+        self.top_down_layer_2 = ResXStage(in_dim       = c3 + round(192*width), 
+                                          out_dim      = round(192*width),
+                                          expand_ratio = self.expand_ratios[1],
+                                          ffn_ratio    = self.ffn_ratios[1],
+                                          num_branches = self.num_branches[1],
+                                          num_blocks   = self.num_blocks[1],
+                                          shortcut     = False,
+                                          act_type     = act_type,
+                                          norm_type    = norm_type,
+                                          depthwise    = depthwise
+                                          )
+
+        # bottom up
+        ## P3 -> P4
+        self.downsample_layer_1 = Conv(round(192*width), round(192*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.bottom_up_layer_1 = ResXStage(in_dim       = round(192*width) + round(192*width),
+                                           out_dim      = round(384*width),
+                                           expand_ratio = self.expand_ratios[2],
+                                           ffn_ratio    = self.ffn_ratios[2],
+                                           num_branches = self.num_branches[2],
+                                           num_blocks   = self.num_blocks[2],
+                                           shortcut     = False,
+                                           act_type     = act_type,
+                                           norm_type    = norm_type,
+                                           depthwise    = depthwise
+                                           )
+
+        ## P4 -> P5
+        self.downsample_layer_2 = Conv(round(384*width), round(384*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.bottom_up_layer_2 = ResXStage(in_dim       = round(384*width) + round(384*width),
+                                           out_dim      = round(768*width),
+                                           expand_ratio = self.expand_ratios[3],
+                                           ffn_ratio    = self.ffn_ratios[3],
+                                           num_branches = self.num_branches[3],
+                                           num_blocks   = self.num_blocks[3],
+                                           shortcut     = False,
+                                           act_type     = act_type,
+                                           norm_type    = norm_type,
+                                           depthwise    = depthwise
+                                           )
+
+        # output projection layers
+        if out_dim is not None:
+            self.out_layers = nn.ModuleList([
+                Conv(in_dim, out_dim, k=1,
+                     norm_type=norm_type, act_type=act_type)
+                for in_dim in [round(192 * width), round(384 * width), round(768 * width)]
+                ])
+            self.out_dim = [out_dim] * 3
+        else:
+            self.out_layers = None
+            self.out_dim = [round(192 * width), round(384 * width), round(768 * width)]
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        c6 = self.reduce_layer_1(c5)
+        c7 = F.interpolate(c6, scale_factor=2.0)   # s32->s16
+        c8 = torch.cat([c7, c4], dim=1)
+        c9 = self.top_down_layer_1(c8)
+        # P3/8
+        c10 = self.reduce_layer_2(c9)
+        c11 = F.interpolate(c10, scale_factor=2.0)   # s16->s8
+        c12 = torch.cat([c11, c3], dim=1)
+        c13 = self.top_down_layer_2(c12)  # to det
+        # P4/16
+        c14 = self.downsample_layer_1(c13)
+        c15 = torch.cat([c14, c10], dim=1)
+        c16 = self.bottom_up_layer_1(c15)  # to det
+        # P5/32
+        c17 = self.downsample_layer_2(c16)
+        c18 = torch.cat([c17, c6], dim=1)
+        c19 = self.bottom_up_layer_2(c18)  # to det
+
+        out_feats = [c13, c16, c19] # [P3, P4, P5]
+
+        # output projection layers
+        if self.out_layers is not None:
+            out_feats_proj = []
+            for feat, layer in zip(out_feats, self.out_layers):
+                out_feats_proj.append(layer(feat))
+            return out_feats_proj
+
+        return out_feats
+
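+# Feature flow (strides in parentheses): backbone [C3(/8), C4(/16), C5(/32)]
+# -> top-down fusion -> bottom-up fusion -> [P3(/8), P4(/16), P5(/32)] with
+# round(192*w)/round(384*w)/round(768*w) channels, or out_dim each when the
+# 1x1 output projections are enabled.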
+
+def build_fpn(cfg, in_dims, out_dim=None):
+    model = cfg['fpn']
+    # build neck
+    if model == 'rtcdetv2_pafpn':
+        fpn_net = RTCDetv2PaFPN(in_dims   = in_dims,
+                                out_dim   = out_dim,
+                                width     = cfg['width'],
+                                depth     = cfg['depth'],
+                                act_type  = cfg['fpn_act'],
+                                norm_type = cfg['fpn_norm'],
+                                depthwise = cfg['fpn_depthwise']
+                                )
+    else:
+        raise NotImplementedError('Unknown FPN type: {}'.format(model))
+
+    return fpn_net
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'width': 1.0,
+        'depth': 1.0,
+        'fpn': 'rtcdetv2_pafpn',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+    }
+    fpn_dims = [192, 384, 768]
+    out_dim = 192
+    # Head-1
+    model = build_fpn(cfg, fpn_dims, out_dim)
+    fpn_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
+    t0 = time.time()
+    outputs = model(fpn_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
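+    # NOTE: thop's profile() reports MACs; the x2 below converts MACs to FLOPs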
+    flops, params = profile(model, inputs=(fpn_feats, ), verbose=False)
+    print('==============================')
+    print('FPN: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('FPN: Params : {:.2f} M'.format(params / 1e6))