Kaynağa Gözat

add aux loss for RTCDet-v2

yjh0410 2 yıl önce
ebeveyn
işleme
767ac95ea3

+ 45 - 3
models/detectors/rtcdet_v2/loss.py

@@ -13,6 +13,8 @@ class Criterion(object):
         self.args = args
         self.device = device
         self.num_classes = num_classes
+        self.max_epoch = args.max_epoch
+        self.no_aug_epoch = args.no_aug_epoch
         self.use_ema_update = cfg['ema_update']
         # ---------------- Loss weight ----------------
         self.loss_cls_weight = cfg['loss_cls_weight']
@@ -105,9 +107,18 @@ class Criterion(object):
             loss_dfl *= bbox_weight
 
         return loss_dfl
-    
+
+    def loss_bboxes_aux(self, pred_delta, gt_box, anchors, stride_tensors):
+        gt_delta_tl = (anchors - gt_box[..., :2]) / stride_tensors
+        gt_delta_rb = (gt_box[..., 2:] - anchors) / stride_tensors
+        gt_delta = torch.cat([gt_delta_tl, gt_delta_rb], dim=1)
+        loss_box_aux = F.l1_loss(pred_delta, gt_delta, reduction='none')
+
+        return loss_box_aux
+
+
     # ----------------- Loss with TAL assigner -----------------
-    def tal_loss(self, outputs, targets):
+    def tal_loss(self, outputs, targets, epoch=0):
         """ Compute loss with TAL assigner """
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device
@@ -213,10 +224,28 @@ class Criterion(object):
                 losses = losses
         )
 
+        # ------------------ Aux regression loss ------------------
+        if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
+            ## delta_preds
+            delta_preds = torch.cat(outputs['pred_delta'], dim=1)
+            delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
+            ## anchor tensors
+            anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
+            anchors_tensors_pos = anchors_tensors.view(-1, 2)[fg_masks]
+            ## stride tensors
+            stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
+            stride_tensors_pos = stride_tensors.view(-1, 1)[fg_masks]
+            ## aux loss
+            loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets_pos, anchors_pos, strides_pos)
+            loss_box_aux = loss_box_aux.sum() / num_fgs
+
+            losses += loss_box_aux
+            loss_dict['loss_box_aux'] = loss_box_aux
+
         return loss_dict
     
     # ----------------- Loss with SimOTA assigner -----------------
-    def ota_loss(self, outputs, targets):
+    def ota_loss(self, outputs, targets, epoch=0):
         """ Compute loss with SimOTA assigner """
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device
@@ -321,6 +350,19 @@ class Criterion(object):
                 losses = losses
         )
 
+        # ------------------ Aux regression loss ------------------
+        if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
+            ## delta_preds
+            delta_preds = torch.cat(outputs['pred_delta'], dim=1)
+            delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
+            ## aux loss
+            loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets, anchors_pos, strides_pos)
+            loss_box_aux = loss_box_aux.sum() / num_fgs
+
+            losses += loss_box_aux
+            loss_dict['loss_box_aux'] = loss_box_aux
+
+
         return loss_dict
 
 

+ 9 - 6
models/detectors/rtcdet_v2/rtcdet_v2_pred.py

@@ -97,6 +97,7 @@ class MultiLevelPredLayer(nn.Module):
         all_cls_preds = []
         all_reg_preds = []
         all_box_preds = []
+        all_delta_preds = []
         for level in range(self.num_levels):
             # pred
             cls_pred, reg_pred = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
@@ -116,21 +117,22 @@ class MultiLevelPredLayer(nn.Module):
             # ----------------------- Decode bbox -----------------------
             B, M = reg_pred.shape[:2]
             # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
-            reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max])
+            delta_pred = reg_pred.reshape([B, M, 4, self.reg_max])
             # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-            reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
+            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
             # [B, reg_max, 4, M] -> [B, 1, 4, M]
-            reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
+            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
             # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-            reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
+            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
             ## tlbr -> xyxy
-            x1y1_pred = anchors[None] - reg_pred_[..., :2] * self.strides[level]
-            x2y2_pred = anchors[None] + reg_pred_[..., 2:] * self.strides[level]
+            x1y1_pred = anchors[None] - delta_pred[..., :2] * self.strides[level]
+            x2y2_pred = anchors[None] + delta_pred[..., 2:] * self.strides[level]
             box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
 
             all_cls_preds.append(cls_pred)
             all_reg_preds.append(reg_pred)
             all_box_preds.append(box_pred)
+            all_delta_preds.append(delta_pred)
             all_anchors.append(anchors)
             all_strides.append(stride_tensor)
         
@@ -138,6 +140,7 @@ class MultiLevelPredLayer(nn.Module):
         outputs = {"pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
                    "pred_reg": all_reg_preds,        # List(Tensor) [B, M, 4*(reg_max)]
                    "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
+                   "pred_delta": all_delta_preds,    # List(Tensor) [B, M, 4]
                    "anchors": all_anchors,           # List(Tensor) [M, 2]
                    "strides": self.strides,          # List(Int) = [8, 16, 32]
                    "stride_tensor": all_strides      # List(Tensor) [M, 1]

+ 5 - 5
models/detectors/yolox/README.md

@@ -4,11 +4,11 @@
 |---------|--------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
 | YOLOX-N | CSPDarkNet-N | 8xb8  |  640  |         30.4           |       48.9        |   7.5             |   2.3              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_n_coco.pth) |
 | YOLOX-S | CSPDarkNet-S | 8xb8  |  640  |         39.0           |       58.8        |   26.8            |   8.9              | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_s_coco.pth) |
-| YOLOX-M | CSPDarkNet-M | 1xb16 |  640  |         44.6           |       63.8        |   74.3            |   25.4             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_m_coco.pth) |
-| YOLOX-L | CSPDarkNet-L | 1xb16 |  640  |         48.7           |       68.0        |   155.4           |   54.2             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_l_coco.pth) |
+| YOLOX-M | CSPDarkNet-M | 8xb8  |  640  |         44.6           |       63.8        |   74.3            |   25.4             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_m_coco.pth) |
+| YOLOX-L | CSPDarkNet-L | 8xb8  |  640  |         48.7           |       68.0        |   155.4           |   54.2             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_l_coco.pth) |
 
 - For training, we train YOLOX series with 300 epochs on COCO.
-- For data augmentation, we use the large scale jitter (LSJ), Mosaic augmentation and Mixup augmentation, following the setting of [YOLOX](https://github.com/ultralytics/yolov5).
-- For optimizer, we use AdamW with weight decay 0.05 and base per image lr 0.001 / 64.
-- For learning rate scheduler, we use linear decay scheduler.
+- For data augmentation, we use the large scale jitter (LSJ), Mosaic augmentation and Mixup augmentation.
+- For optimizer, we use SGD with weight decay 0.0005 and base per image lr 0.01 / 64.
+- For learning rate scheduler, we use Cosine decay scheduler.
 - I am trying to retrain **YOLOX-M** and **YOLOX-L** with more GPUs, and I will update the AP of YOLOX-M and YOLOX-L in the table in the future.