yjh0410 1 gadu atpakaļ
vecāks
revīzija
b1bcbb18f8

+ 1 - 6
odlab/engine.py

@@ -1,7 +1,3 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-"""
-Train and eval functions used in main.py
-"""
 import math
 import sys
 from typing import Iterable
@@ -59,8 +55,7 @@ def train_one_epoch(cfg,
 
         # Compute loss
         loss_dict = criterion(outputs, targets)
-        loss_weight_dict = criterion.weight_dict
-        losses = sum(loss_dict[k] * loss_weight_dict[k] for k in loss_dict.keys() if k in loss_weight_dict)
+        losses = loss_dict["losses"]# sum(loss_dict[k] * loss_weight_dict[k] for k in loss_dict.keys() if k in loss_weight_dict)
         loss_value = losses.item()
         losses /= cfg.grad_accumulate
 

+ 8 - 24
odlab/models/detectors/detr/criterion.py

@@ -97,10 +97,11 @@ class SetCriterion(nn.Module):
         num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
 
         # Compute all the requested losses
-        losses = {}
+        loss_dict = {}
         for loss in self.losses:
             l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
-            losses.update(l_dict)
+            l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
+            loss_dict.update(l_dict)
 
         # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
         if 'aux_outputs' in outputs:
@@ -108,28 +109,11 @@ class SetCriterion(nn.Module):
                 indices = self.matcher(aux_outputs, targets)
                 for loss in self.losses:
                     l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
+                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
                     l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()}
-                    losses.update(l_dict)
+                    loss_dict.update(l_dict)
 
-        return losses
+        # Total loss
+        loss_dict["losses"] = sum(loss_dict.values())
 
-    @staticmethod
-    def get_cdn_matched_indices(dn_meta, targets):
-        '''get_cdn_matched_indices
-        '''
-        dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
-        num_gts = [len(t['labels']) for t in targets]
-        device = targets[0]['labels'].device
-        
-        dn_match_indices = []
-        for i, num_gt in enumerate(num_gts):
-            if num_gt > 0:
-                gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
-                gt_idx = gt_idx.tile(dn_num_group)
-                assert len(dn_positive_idx[i]) == len(gt_idx)
-                dn_match_indices.append((dn_positive_idx[i], gt_idx))
-            else:
-                dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device), \
-                    torch.zeros(0, dtype=torch.int64,  device=device)))
-        
-        return dn_match_indices
+        return loss_dict

+ 7 - 0
odlab/models/detectors/fcos/criterion.py

@@ -176,10 +176,14 @@ class SetCriterion(nn.Module):
             pred_ctn[foreground_idxs],  gt_centerness[foreground_idxs], reduction='none')
         loss_centerness = loss_centerness.sum() / num_foreground
 
+        total_loss = loss_labels * self.weight_dict["loss_cls"] + \
+                     loss_bboxes * self.weight_dict["loss_reg"] + \
+                     loss_centerness * self.weight_dict["loss_ctn"]
         loss_dict = dict(
                 loss_cls = loss_labels,
                 loss_reg = loss_bboxes,
                 loss_ctn = loss_centerness,
+                losses   = total_loss,
         )
 
         return loss_dict
@@ -254,9 +258,12 @@ class SetCriterion(nn.Module):
         box_weight = assign_metrics[foreground_idxs]
         loss_bboxes = self.loss_bboxes_xyxy(box_preds_pos, box_targets_pos, num_fgs, box_weight)
 
+        total_loss = loss_labels * self.weight_dict["loss_cls"] + \
+                     loss_bboxes * self.weight_dict["loss_reg"]
         loss_dict = dict(
                 loss_cls = loss_labels,
                 loss_reg = loss_bboxes,
+                losses   = total_loss,
         )
 
         return loss_dict

+ 3 - 3
odlab/models/detectors/fcos_e2e/criterion.py

@@ -133,11 +133,11 @@ class SetCriterion(nn.Module):
         loss_bboxes = self.loss_bboxes(box_preds_pos, box_targets_pos, num_fgs, box_weight)
 
         total_loss = loss_labels * self.weight_dict["loss_cls"] + \
-                         loss_bboxes * self.weight_dict["loss_reg"]
+                     loss_bboxes * self.weight_dict["loss_reg"]
         loss_dict = dict(
                 loss_cls = loss_labels,
                 loss_reg = loss_bboxes,
-                loss     = total_loss,
+                losses   = total_loss,
         )
 
         return loss_dict
@@ -158,7 +158,7 @@ class SetCriterion(nn.Module):
         o2o_loss_dict = self.compute_loss(outputs["outputs_o2o"], targets)
 
         loss_dict = {}
-        loss_dict["loss"] = o2o_loss_dict["loss"] + o2m_loss_dict["loss"]
+        loss_dict["losses"] = o2o_loss_dict["losses"] + o2m_loss_dict["losses"]
         for k in o2m_loss_dict:
             loss_dict['o2m_' + k] = o2m_loss_dict[k]
         for k in o2o_loss_dict:

+ 3 - 0
odlab/models/detectors/yolof/criterion.py

@@ -134,9 +134,12 @@ class SetCriterion(nn.Module):
         matched_pred_box = pred_box.reshape(-1, 4)[src_idx[~pos_ignore_idx.cpu()]]
         loss_bboxes = self.loss_bboxes(matched_pred_box, tgt_boxes, num_foreground)
 
+        total_loss = loss_labels * self.weight_dict["loss_cls"] + \
+                     loss_bboxes * self.weight_dict["loss_reg"]
         loss_dict = dict(
                 loss_cls = loss_labels,
                 loss_reg = loss_bboxes,
+                losses   = total_loss,
         )
 
         return loss_dict