yjh0410 1 år sedan
förälder
incheckning
e636098379
1 ändrade filer med 16 tillägg och 13 borttagningar
  1. yolo/models/yolov6/loss.py (+16 -13)

+ 16 - 13
yolo/models/yolov6/loss.py

@@ -23,16 +23,19 @@ class SetCriterion(object):
                                            beta            = cfg.tal_beta
                                            )
 
-    def loss_classes(self, pred_cls, gt_score, gt_label):
-        # Compute VFL
-        pred_score = F.sigmoid(pred_cls).detach()
-        target = F.one_hot(gt_label, num_classes=self.num_classes + 1)[..., :-1]
-        weight = 0.75 * pred_score.pow(2.0) * (1 - target) + gt_score
+    def loss_classes(self, pred_logits, gt_score, gt_label, fg_mask):
+        gt_label = torch.where(fg_mask > 0, gt_label, torch.full_like(gt_label, self.num_classes))
+        one_hot_label = F.one_hot(gt_label.long(), self.num_classes + 1)[..., :-1]
 
-        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, weight=weight, reduction='none')
+        pred_score = pred_logits.sigmoid()
+        weight = 0.75 * pred_score.pow(2.0) * (1 - one_hot_label) + gt_score * one_hot_label
+        with torch.cuda.amp.autocast(enabled=False):
+            loss_cls = F.binary_cross_entropy_with_logits(pred_logits.float(), gt_score.float(), reduction='none')
+            loss_cls = loss_cls * weight
+            loss_cls = loss_cls.sum()
 
         return loss_cls
-        
+    
     def loss_bboxes(self, pred_box, gt_box, bbox_weight):
         # regression loss
         ious = bbox_iou(pred_box, gt_box, xywh=False, GIoU=True)
@@ -64,9 +67,9 @@ class SetCriterion(object):
         gt_score_targets = []
         gt_bbox_targets = []
         fg_masks = []
-        for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
-            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
+        for bid in range(bs):
+            tgt_labels = targets[bid]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[bid]["boxes"].to(device)        # [Mp, 4]
 
             if self.cfg.normalize_coords:
                 img_h, img_w = outputs['image_size']
@@ -95,8 +98,8 @@ class SetCriterion(object):
                     fg_mask,    # [1, M,]
                     _
                 ) = self.matcher(
-                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
-                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    pd_scores = cls_preds[bid:bid+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[bid:bid+1].detach(),
                     anc_points = anchors,
                     gt_labels = tgt_labels,
                     gt_bboxes = tgt_boxs
@@ -120,7 +123,7 @@ class SetCriterion(object):
 
         # ------------------ Classification loss ------------------
         cls_preds = cls_preds.view(-1, self.num_classes)
-        loss_cls = self.loss_classes(cls_preds, gt_score_targets, gt_label_targets)
+        loss_cls = self.loss_classes(cls_preds, gt_score_targets, gt_label_targets, fg_masks)
         loss_cls = loss_cls.sum() / num_fgs
 
         # ------------------ Regression loss ------------------