Parcourir la source

add ViT into iclab/

yjh0410 il y a 1 an
Parent
commit
b0a8792a14
1 fichiers modifiés avec 11 ajouts et 15 suppressions
  1. 11 15
      yolo/models/yolov6/loss.py

+ 11 - 15
yolo/models/yolov6/loss.py

@@ -23,16 +23,15 @@ class SetCriterion(object):
                                            beta            = cfg.tal_beta
                                            )
 
-    def loss_classes(self, pred_logits, gt_score, gt_label, fg_mask):
-        gt_label = torch.where(fg_mask > 0, gt_label, torch.full_like(gt_label, self.num_classes))
-        one_hot_label = F.one_hot(gt_label.long(), self.num_classes + 1)[..., :-1]
-
-        pred_score = pred_logits.sigmoid()
-        weight = 0.75 * pred_score.pow(2.0) * (1 - one_hot_label) + gt_score * one_hot_label
-        with torch.cuda.amp.autocast(enabled=False):
-            loss_cls = F.binary_cross_entropy_with_logits(pred_logits.float(), gt_score.float(), reduction='none')
-            loss_cls = loss_cls * weight
-            loss_cls = loss_cls.sum()
+    def loss_classes(self, pred_logits, gt_score):
+        alpha, gamma = 0.75, 2.0
+        pred_sigmoid = pred_logits.sigmoid()
+        focal_weight = gt_score * (gt_score > 0.0).float() + \
+            alpha * (pred_sigmoid - gt_score).abs().pow(gamma) * \
+            (gt_score <= 0.0).float()
+        
+        loss_cls = F.binary_cross_entropy_with_logits(
+            pred_logits, gt_score, reduction='none') * focal_weight
 
         return loss_cls
     
@@ -63,7 +62,6 @@ class SetCriterion(object):
         anchors = torch.cat(outputs['anchors'], dim=0)
         
         # --------------- label assignment ---------------
-        gt_label_targets = []
         gt_score_targets = []
         gt_bbox_targets = []
         fg_masks = []
@@ -92,7 +90,7 @@ class SetCriterion(object):
                 tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
                 tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
                 (
-                    gt_label,   # [1, M]
+                    _,          # [1, M]
                     gt_box,     # [1, M, 4]
                     gt_score,   # [1, M, C]
                     fg_mask,    # [1, M,]
@@ -104,14 +102,12 @@ class SetCriterion(object):
                     gt_labels = tgt_labels,
                     gt_bboxes = tgt_boxs
                     )
-            gt_label_targets.append(gt_label)
             gt_score_targets.append(gt_score)
             gt_bbox_targets.append(gt_box)
             fg_masks.append(fg_mask)
 
         # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
         fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
-        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
         gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
         gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
         num_fgs = gt_score_targets.sum()
@@ -123,7 +119,7 @@ class SetCriterion(object):
 
         # ------------------ Classification loss ------------------
         cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, gt_score_targets, gt_label_targets, fg_masks)
+        loss_cls = self.loss_classes(cls_preds, gt_score_targets)
         loss_cls = loss_cls.sum() / num_fgs
 
         # ------------------ Regression loss ------------------