|
|
@@ -97,10 +97,11 @@ class SetCriterion(nn.Module):
|
|
|
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
|
|
|
|
|
|
# Compute all the requested losses
|
|
|
- losses = {}
|
|
|
+ loss_dict = {}
|
|
|
for loss in self.losses:
|
|
|
l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
|
|
|
- losses.update(l_dict)
|
|
|
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
|
|
+ loss_dict.update(l_dict)
|
|
|
|
|
|
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
|
|
if 'aux_outputs' in outputs:
|
|
|
@@ -108,28 +109,11 @@ class SetCriterion(nn.Module):
|
|
|
indices = self.matcher(aux_outputs, targets)
|
|
|
for loss in self.losses:
|
|
|
l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
|
|
|
+ l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
|
|
|
l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()}
|
|
|
- losses.update(l_dict)
|
|
|
+ loss_dict.update(l_dict)
|
|
|
|
|
|
- return losses
|
|
|
+ # Total loss
|
|
|
+ loss_dict["losses"] = sum(loss_dict.values())
|
|
|
|
|
|
- @staticmethod
|
|
|
- def get_cdn_matched_indices(dn_meta, targets):
|
|
|
- '''get_cdn_matched_indices
|
|
|
- '''
|
|
|
- dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
|
|
|
- num_gts = [len(t['labels']) for t in targets]
|
|
|
- device = targets[0]['labels'].device
|
|
|
-
|
|
|
- dn_match_indices = []
|
|
|
- for i, num_gt in enumerate(num_gts):
|
|
|
- if num_gt > 0:
|
|
|
- gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
|
|
|
- gt_idx = gt_idx.tile(dn_num_group)
|
|
|
- assert len(dn_positive_idx[i]) == len(gt_idx)
|
|
|
- dn_match_indices.append((dn_positive_idx[i], gt_idx))
|
|
|
- else:
|
|
|
- dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device), \
|
|
|
- torch.zeros(0, dtype=torch.int64, device=device)))
|
|
|
-
|
|
|
- return dn_match_indices
|
|
|
+ return loss_dict
|