train RTCDet-v2-N on COCO

yjh0410 committed 2 years ago
Commit 0a8380ece1

+ 12 - 18
config/model_config/rtcdet_v2_config.py

@@ -42,23 +42,20 @@ rtcdet_v2_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.5],   # 320 -> 960
-        'trans_type': 'yolov5_nano',
+        'trans_type': 'rtcdet_v1_nano',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': {'tal': {'topk': 10,
-                            'alpha': 0.5,
-                            'beta': 6.0},
-                    'ota': {'center_sampling_radius': 2.5,
+        'matcher': {'ota': {'center_sampling_radius': 2.5,
                              'topk_candidate': 10},
                     },
         # ---------------- Loss config ----------------
         ## Loss weight
         'ema_update': False,
-        'loss_cls_weight': 0.5,
-        'loss_box_weight': 7.0,
-        'loss_dfl_weight': 1.5,
+        'loss_cls_weight': 1.0,
+        'loss_box_weight': 5.0,
+        'loss_dfl_weight': 1.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolov8',
+        'trainer_type': 'rtmdet',
     },
 
     'rtcdet_v2_l':{
@@ -101,23 +98,20 @@ rtcdet_v2_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.25],   # 320 -> 800
-        'trans_type': 'yolov5_large',
+        'trans_type': 'rtcdet_v1_large',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': {'tal': {'topk': 10,
-                            'alpha': 0.5,
-                            'beta': 6.0},
-                    'ota': {'center_sampling_radius': 2.5,
+        'matcher': {'ota': {'center_sampling_radius': 2.5,
                              'topk_candidate': 10},
                     },
         # ---------------- Loss config ----------------
         ## Loss weight
         'ema_update': False,
-        'loss_cls_weight': 0.5,
-        'loss_box_weight': 7.0,
-        'loss_dfl_weight': 1.5,
+        'loss_cls_weight': 1.0,
+        'loss_box_weight': 5.0,
+        'loss_dfl_weight': 1.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolov8',
+        'trainer_type': 'rtmdet',
     },
 
 }
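
Note: a minimal Python sketch (not the repo's dataloader code) of how the 'multi_scale' factors above map to the "320 -> 960" and "320 -> 800" comments, assuming the 640 base size passed via -size in train_ddp.sh and a hypothetical rounding to a stride of 32:

    def multi_scale_range(multi_scale, base_size=640, stride=32):
        # Scale the base input size by the [min, max] factors and snap to the stride.
        lo = int(round(multi_scale[0] * base_size / stride)) * stride
        hi = int(round(multi_scale[1] * base_size / stride)) * stride
        return lo, hi

    print(multi_scale_range([0.5, 1.5]))    # (320, 960)  -- rtcdet_v2_n: "320 -> 960"
    print(multi_scale_range([0.5, 1.25]))   # (320, 800)  -- rtcdet_v2_l: "320 -> 800"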

+ 2 - 124
models/detectors/rtcdet_v2/loss.py

@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from utils.box_ops import bbox2dist, get_ious
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
-from .matcher import TaskAlignedAssigner, AlignedSimOTA
+from .matcher import AlignedSimOTA
 
 
 class Criterion(object):
@@ -20,13 +20,6 @@ class Criterion(object):
         self.loss_dfl_weight = cfg['loss_dfl_weight']
         # ---------------- Matcher ----------------
         matcher_config = cfg['matcher']
-        ## TAL assigner
-        self.tal_matcher = TaskAlignedAssigner(
-            topk=matcher_config['tal']['topk'],
-            alpha=matcher_config['tal']['alpha'],
-            beta=matcher_config['tal']['beta'],
-            num_classes=num_classes
-            )
         ## SimOTA assigner
         self.ota_matcher = AlignedSimOTA(
             center_sampling_radius=matcher_config['ota']['center_sampling_radius'],
@@ -34,12 +27,6 @@ class Criterion(object):
             num_classes=num_classes
         )
 
-    def __call__(self, outputs, targets, epoch=0):
-        if epoch < self.args.wp_epoch:
-            return self.ota_loss(outputs, targets)
-        else:
-            return self.tal_loss(outputs, targets)
-
     def ema_update(self, name: str, value, initial_value, momentum=0.9):
         if hasattr(self, name):
             old = getattr(self, name)
@@ -106,117 +93,8 @@ class Criterion(object):
 
         return loss_dfl
     
-    # ----------------- Loss with TAL assigner -----------------
-    def tal_loss(self, outputs, targets):
-        """ Compute loss with TAL assigner """
-        bs = outputs['pred_cls'][0].shape[0]
-        device = outputs['pred_cls'][0].device
-        anchors = torch.cat(outputs['anchors'], dim=0)
-        num_anchors = anchors.shape[0]
-        # preds: [B, M, C]
-        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
-        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
-        box_preds = torch.cat(outputs['pred_box'], dim=1)
-
-        # --------------- label assignment ---------------
-        gt_label_targets = []
-        gt_score_targets = []
-        gt_bbox_targets = []
-        fg_masks = []
-        for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)
-            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
-
-            # check target
-            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
-                # There is no valid gt
-                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
-                gt_label = cls_preds.new_zeros((1, num_anchors,))                  #[1, M,]
-                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
-                gt_box = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
-            else:
-                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
-                tgt_bboxes = tgt_bboxes[None]                   # [1, Mp, 4]
-                (
-                    gt_label,   #[1, M]
-                    gt_box,     #[1, M, 4]
-                    gt_score,   #[1, M, C]
-                    fg_mask,    #[1, M,]
-                    _
-                ) = self.tal_matcher(
-                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
-                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
-                    anc_points = anchors,
-                    gt_labels = tgt_labels,
-                    gt_bboxes = tgt_bboxes
-                    )
-            gt_label_targets.append(gt_label)
-            gt_score_targets.append(gt_score)
-            gt_bbox_targets.append(gt_box)
-            fg_masks.append(fg_mask)
-
-        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
-        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
-        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
-        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
-        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
-        gt_label_targets = torch.where(fg_masks > 0, gt_label_targets, torch.full_like(gt_label_targets, self.num_classes))
-        gt_labels_one_hot = F.one_hot(gt_label_targets.long(), self.num_classes + 1)[..., :-1]
-        bbox_weight = gt_score_targets[fg_masks].sum(-1)
-        num_fgs = max(gt_score_targets.sum(), 1)
-
-        # average loss normalizer across all the GPUs
-        if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_fgs)
-        num_fgs = max(num_fgs / get_world_size(), 1.0)
-
-        # update loss normalizer with EMA
-        if self.use_ema_update:
-            normalizer = self.ema_update("loss_normalizer", max(num_fgs, 1), 100)
-        else:
-            normalizer = num_fgs
-
-        # ------------------ Classification loss ------------------
-        cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, gt_score_targets, gt_labels_one_hot, vfl=False)
-        loss_cls = loss_cls.sum() / normalizer
-
-        # ------------------ Regression loss ------------------
-        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
-        box_targets_pos = gt_bbox_targets[fg_masks]
-        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
-        loss_box = loss_box.sum() / normalizer
-
-        # ------------------ Distribution focal loss  ------------------
-        ## process anchors
-        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
-        ## process stride tensors
-        strides = torch.cat(outputs['stride_tensor'], dim=0)
-        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
-        ## fg preds
-        reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
-        anchors_pos = anchors[fg_masks]
-        strides_pos = strides[fg_masks]
-        ## compute dfl
-        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
-        loss_dfl = loss_dfl.sum() / normalizer
-
-        # total loss
-        losses = self.loss_cls_weight * loss_cls + \
-                 self.loss_box_weight * loss_box + \
-                 self.loss_dfl_weight * loss_dfl
-
-        loss_dict = dict(
-                loss_cls = loss_cls,
-                loss_box = loss_box,
-                loss_dfl = loss_dfl,
-                losses = losses
-        )
-
-        return loss_dict
-    
     # ----------------- Loss with SimOTA assigner -----------------
-    def ota_loss(self, outputs, targets):
+    def __call__(self, outputs, targets, epoch=0):
         """ Compute loss with SimOTA assigner """
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device
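
Note: with the TAL branch deleted, Criterion.__call__ now always runs the SimOTA-based loss (the epoch argument is kept only for interface compatibility). The hunk also retains the EMA loss-normalizer helper whose body is truncated above; below is a standalone sketch of the standard EMA update it presumably follows — the class name and the exact update rule are assumptions, not verified repo code:

    class EMANormalizer:
        # Hypothetical stand-in for Criterion's EMA state; only ema_update is sketched.
        def ema_update(self, name: str, value, initial_value, momentum=0.9):
            if hasattr(self, name):
                old = getattr(self, name)
                new = momentum * old + (1.0 - momentum) * value   # standard EMA step (assumed rule)
            else:
                new = initial_value                               # first call: seed with initial_value
            setattr(self, name, new)
            return new

    norm = EMANormalizer()
    print(norm.ema_update("loss_normalizer", 32.0, 100))  # 100  (seeded with initial_value)
    print(norm.ema_update("loss_normalizer", 32.0, 100))  # 93.2 = 0.9 * 100 + 0.1 * 32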

+ 2 - 194
models/detectors/rtcdet_v2/matcher.py

@@ -1,201 +1,8 @@
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
-from utils.box_ops import box_iou, bbox_iou
+from utils.box_ops import box_iou
 
 
-# -------------------------- Basic Functions --------------------------
-def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
-    """select the positive anchors's center in gt
-    Args:
-        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
-        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
-    Return:
-        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    n_anchors = xy_centers.size(0)
-    bs, n_max_boxes, _ = gt_bboxes.size()
-    _gt_bboxes = gt_bboxes.reshape([-1, 4])
-    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
-    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
-    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
-    b_lt = xy_centers - gt_bboxes_lt
-    b_rb = gt_bboxes_rb - xy_centers
-    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
-    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
-    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
-
-def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
-    """if an anchor box is assigned to multiple gts,
-        the one with the highest iou will be selected.
-    Args:
-        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    Return:
-        target_gt_idx (Tensor): shape(bs, num_total_anchors)
-        fg_mask (Tensor): shape(bs, num_total_anchors)
-        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    fg_mask = mask_pos.sum(axis=-2)
-    if fg_mask.max() > 1:
-        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1])
-        max_overlaps_idx = overlaps.argmax(axis=1)
-        is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes)
-        is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)
-        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos)
-        fg_mask = mask_pos.sum(axis=-2)
-    target_gt_idx = mask_pos.argmax(axis=-2)
-    return target_gt_idx, fg_mask , mask_pos
-
-def iou_calculator(box1, box2, eps=1e-9):
-    """Calculate iou for batch
-    Args:
-        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
-        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
-    Return:
-        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
-    """
-    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
-    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
-    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
-    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
-    x1y1 = torch.maximum(px1y1, gx1y1)
-    x2y2 = torch.minimum(px2y2, gx2y2)
-    overlap = (x2y2 - x1y1).clip(0).prod(-1)
-    area1 = (px2y2 - px1y1).clip(0).prod(-1)
-    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
-    union = area1 + area2 - overlap + eps
-
-    return overlap / union
-
-
-# -------------------------- Task Aligned Assigner --------------------------
-class TaskAlignedAssigner(nn.Module):
-    def __init__(self, topk=10, alpha=0.5, beta=6.0, eps=1e-9, num_classes=80):
-        super(TaskAlignedAssigner, self).__init__()
-        self.topk = topk
-        self.num_classes = num_classes
-        self.bg_idx = num_classes
-        self.alpha = alpha
-        self.beta = beta
-        self.eps = eps
-
-    @torch.no_grad()
-    def forward(self,
-                pd_scores,
-                pd_bboxes,
-                anc_points,
-                gt_labels,
-                gt_bboxes):
-        """This code referenced to
-           https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
-        Args:
-            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
-            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
-            anc_points (Tensor): shape(num_total_anchors, 2)
-            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
-            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
-        Returns:
-            target_labels (Tensor): shape(bs, num_total_anchors)
-            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
-            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
-            fg_mask (Tensor): shape(bs, num_total_anchors)
-        """
-        self.bs = pd_scores.size(0)
-        self.n_max_boxes = gt_bboxes.size(1)
-
-        mask_pos, align_metric, overlaps = self.get_pos_mask(
-            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
-
-        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
-            mask_pos, overlaps, self.n_max_boxes)
-
-        # assigned target
-        target_labels, target_bboxes, target_scores = self.get_targets(
-            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
-
-        # normalize
-        align_metric *= mask_pos
-        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
-        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
-        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
-        target_scores = target_scores * norm_align_metric
-
-        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
-
-
-    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
-        # get anchor_align metric, (b, max_num_obj, h*w)
-        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
-        # get in_gts mask, (b, max_num_obj, h*w)
-        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
-        # get topk_metric mask, (b, max_num_obj, h*w)
-        mask_topk = self.select_topk_candidates(align_metric * mask_in_gts)
-        # merge all mask to a final mask, (b, max_num_obj, h*w)
-        mask_pos = mask_topk * mask_in_gts
-
-        return mask_pos, align_metric, overlaps
-
-
-    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
-        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
-        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
-        ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
-        # get the scores of each grid for each gt cls
-        bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
-
-        overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False).squeeze(3).clamp(0)
-        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
-
-        return align_metric, overlaps
-
-
-    def select_topk_candidates(self, metrics, largest=True):
-        """
-        Args:
-            metrics: (b, max_num_obj, h*w).
-            topk_mask: (b, max_num_obj, topk) or None
-        """
-
-        num_anchors = metrics.shape[-1]  # h*w
-        # (b, max_num_obj, topk)
-        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
-        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).tile([1, 1, self.topk])
-        # (b, max_num_obj, topk)
-        topk_idxs[~topk_mask] = 0
-        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
-        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
-        # filter invalid bboxes
-        is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
-        return is_in_topk.to(metrics.dtype)
-
-
-    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
-        """
-        Args:
-            gt_labels: (b, max_num_obj, 1)
-            gt_bboxes: (b, max_num_obj, 4)
-            target_gt_idx: (b, h*w)
-            fg_mask: (b, h*w)
-        """
-
-        # assigned target labels, (b, 1)
-        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
-        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
-        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
-
-        # assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
-        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
-
-        # assigned target scores
-        target_labels.clamp(0)
-        target_scores = F.one_hot(target_labels, self.num_classes)  # (b, h*w, 80)
-        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
-        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
-
-        return target_labels, target_bboxes, target_scores
-    
-
 # -------------------------- Aligned SimOTA Assigner --------------------------
 class AlignedSimOTA(object):
     """
@@ -240,6 +47,7 @@ class AlignedSimOTA(object):
             # prepare cls_target
             cls_targets = F.one_hot(tgt_labels.long(), self.num_classes).float()
             cls_targets = cls_targets.unsqueeze(1).repeat(1, score_preds.size(1), 1)
+            cls_targets *= pair_wise_ious.unsqueeze(-1)  # iou-aware
             # [N, Mp]
             cls_cost = F.binary_cross_entropy(score_preds, cls_targets, reduction="none").sum(-1)
         del score_preds
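
Note: the single added line makes the SimOTA classification cost IoU-aware by scaling the one-hot class targets with the pairwise GT/prediction IoUs before the BCE matching cost. A self-contained toy illustration of that computation (shapes and values are made up; only the target/cost lines mirror the diff):

    import torch
    import torch.nn.functional as F

    num_classes = 80
    tgt_labels = torch.tensor([3, 17])                 # [N]     toy ground-truth class ids
    pair_wise_ious = torch.rand(2, 100)                # [N, Mp] IoU of each GT with each candidate box
    score_preds = torch.rand(2, 100, num_classes)      # [N, Mp, C] sigmoid class scores

    cls_targets = F.one_hot(tgt_labels.long(), num_classes).float()             # [N, C]
    cls_targets = cls_targets.unsqueeze(1).repeat(1, score_preds.size(1), 1)    # [N, Mp, C]
    cls_targets *= pair_wise_ious.unsqueeze(-1)        # the new iou-aware scaling
    cls_cost = F.binary_cross_entropy(score_preds, cls_targets, reduction="none").sum(-1)  # [N, Mp]
    print(cls_cost.shape)                              # torch.Size([2, 100])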

+ 1 - 1
models/detectors/rtcdet_v2/rtcdet_v2_backbone.py

@@ -9,7 +9,7 @@ except:
 
 model_urls = {
     'mcnet_p': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/mcnet_pico.pth",
-    'mcnet_n': None,
+    'mcnet_n': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/mcnet_nano.pth",
     'mcnet_t': None,
     'mcnet_s': None,
     'mcnet_m': None,
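
Note: a hedged sketch of fetching the newly added 'mcnet_n' checkpoint with torch.hub; the URL is the one patched in above, but the checkpoint layout (plain state_dict vs. a dict wrapping it under a 'model' key) is an assumption:

    import torch

    url = "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/mcnet_nano.pth"
    checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
    # The release file may be a plain state_dict or wrap it under a 'model' key (assumption).
    state_dict = checkpoint["model"] if isinstance(checkpoint, dict) and "model" in checkpoint else checkpoint
    # backbone.load_state_dict(state_dict, strict=False)  # hypothetical hook into the backbone builder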

+ 1 - 1
train_ddp.sh

@@ -5,7 +5,7 @@ python -m torch.distributed.run --nproc_per_node=8 train.py \
                                                     -dist \
                                                     -d coco \
                                                     --root /data/datasets/ \
-                                                    -m rtcdet_v2_l\
+                                                    -m rtcdet_v2_n\
                                                     -bs 128 \
                                                     -size 640 \
                                                     --wp_epoch 3 \