@@ -20,14 +20,21 @@ class SetCriterion(nn.Module):
         super().__init__()
         self.num_classes = cfg.num_classes
         self.losses = ['labels', 'boxes']
+        self.eos_coef = 0.1
+
         # -------- Loss weights --------
         self.weight_dict = {'loss_cls': cfg.loss_cls,
                             'loss_box': cfg.loss_box,
                             'loss_giou': cfg.loss_giou}
         for i in range(cfg.num_dec_layers - 1):
             self.weight_dict.update({k + f'_aux_{i}': v for k, v in self.weight_dict.items()})
+        empty_weight = torch.ones(self.num_classes + 1)
+        empty_weight[-1] = self.eos_coef
+        self.register_buffer('empty_weight', empty_weight)
+
         # -------- Matcher --------
-        self.matcher = HungarianMatcher(cfg.cost_class, cfg.cost_bbox, cfg.cost_giou)
+        matcher_hpy = cfg.matcher_hpy
+        self.matcher = HungarianMatcher(matcher_hpy['cost_class'], matcher_hpy['cost_bbox'], matcher_hpy['cost_giou'])
 
     def loss_labels(self, outputs, targets, indices, num_boxes):
         assert 'pred_logits' in outputs
@@ -37,7 +44,7 @@ class SetCriterion(nn.Module):
         target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
         target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                     dtype=torch.int64, device=src_logits.device)
         target_classes[idx] = target_classes_o
 
         loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
 
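For reference, a minimal standalone sketch of the class-weighting mechanism this patch wires up. The `eos_coef = 0.1` value and the `empty_weight` construction follow the diff; the number of classes, tensor shapes, and the matched query are illustrative. The no-object class sits at index `num_classes`, unmatched queries default to it, and `empty_weight` scales its contribution to the cross-entropy.

```python
import torch
import torch.nn.functional as F

num_classes = 3                                # illustrative; the real value comes from cfg.num_classes
eos_coef = 0.1                                 # as in the patch: down-weight for the no-object class

empty_weight = torch.ones(num_classes + 1)     # one weight per object class, plus the no-object slot
empty_weight[-1] = eos_coef

logits = torch.randn(2, 5, num_classes + 1)    # (batch, num_queries, num_classes + 1)
target_classes = torch.full((2, 5), num_classes, dtype=torch.int64)  # every query defaults to no-object
target_classes[0, 0] = 1                       # pretend the matcher assigned query 0 of image 0 to class 1

# F.cross_entropy expects (batch, classes, queries), hence the transpose in loss_labels.
loss_cls = F.cross_entropy(logits.transpose(1, 2), target_classes, empty_weight)
print(loss_cls.item())
```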
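A standalone trace of the auxiliary-loss weight expansion in the first hunk, with illustrative weights and `cfg.num_dec_layers = 3` assumed. Because `self.weight_dict` is updated in place and re-read on the next iteration, later iterations also re-suffix the keys added earlier (`loss_cls_aux_0_aux_1`, ...); whether those extra entries matter depends on how the training loop consumes `weight_dict`. DETR's reference implementation builds the auxiliary entries in a separate dict and merges them once, as sketched below.

```python
# Trace of the loop as written in the patch (illustrative weights, num_dec_layers = 3).
weight_dict = {'loss_cls': 1.0, 'loss_box': 5.0, 'loss_giou': 2.0}
for i in range(3 - 1):
    weight_dict.update({k + f'_aux_{i}': v for k, v in weight_dict.items()})
print(sorted(weight_dict))
# ..., 'loss_cls_aux_0', 'loss_cls_aux_0_aux_1', 'loss_cls_aux_1', ...  <- doubled suffixes appear

# DETR-style alternative: derive aux entries from the base weights only, then merge once.
base = {'loss_cls': 1.0, 'loss_box': 5.0, 'loss_giou': 2.0}
aux = {k + f'_aux_{i}': v for i in range(3 - 1) for k, v in base.items()}
base.update(aux)
print(sorted(base))
```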