|
@@ -20,6 +20,7 @@ class Criterion(object):
|
|
|
self.loss_cls_weight = cfg['loss_cls_weight']
|
|
self.loss_cls_weight = cfg['loss_cls_weight']
|
|
|
self.loss_box_weight = cfg['loss_box_weight']
|
|
self.loss_box_weight = cfg['loss_box_weight']
|
|
|
self.loss_dfl_weight = cfg['loss_dfl_weight']
|
|
self.loss_dfl_weight = cfg['loss_dfl_weight']
|
|
|
|
|
+ self.loss_box_aux = cfg['loss_box_aux']
|
|
|
# ---------------- Matcher ----------------
|
|
# ---------------- Matcher ----------------
|
|
|
matcher_config = cfg['matcher']
|
|
matcher_config = cfg['matcher']
|
|
|
## TAL assigner
|
|
## TAL assigner
|
|
@@ -116,7 +117,6 @@ class Criterion(object):
|
|
|
|
|
|
|
|
return loss_box_aux
|
|
return loss_box_aux
|
|
|
|
|
|
|
|
-
|
|
|
|
|
# ----------------- Loss with TAL assigner -----------------
|
|
# ----------------- Loss with TAL assigner -----------------
|
|
|
def tal_loss(self, outputs, targets, epoch=0):
|
|
def tal_loss(self, outputs, targets, epoch=0):
|
|
|
""" Compute loss with TAL assigner """
|
|
""" Compute loss with TAL assigner """
|
|
@@ -225,16 +225,10 @@ class Criterion(object):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# ------------------ Aux regression loss ------------------
|
|
# ------------------ Aux regression loss ------------------
|
|
|
- if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
|
|
|
|
|
|
|
+ if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.loss_box_aux:
|
|
|
## delta_preds
|
|
## delta_preds
|
|
|
delta_preds = torch.cat(outputs['pred_delta'], dim=1)
|
|
delta_preds = torch.cat(outputs['pred_delta'], dim=1)
|
|
|
delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
|
|
delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
|
|
|
- ## anchor tensors
|
|
|
|
|
- anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
|
|
|
|
|
- anchors_tensors_pos = anchors_tensors.view(-1, 2)[fg_masks]
|
|
|
|
|
- ## stride tensors
|
|
|
|
|
- stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
|
|
|
|
|
- stride_tensors_pos = stride_tensors.view(-1, 1)[fg_masks]
|
|
|
|
|
## aux loss
|
|
## aux loss
|
|
|
loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets_pos, anchors_pos, strides_pos)
|
|
loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets_pos, anchors_pos, strides_pos)
|
|
|
loss_box_aux = loss_box_aux.sum() / num_fgs
|
|
loss_box_aux = loss_box_aux.sum() / num_fgs
|
|
@@ -351,7 +345,7 @@ class Criterion(object):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# ------------------ Aux regression loss ------------------
|
|
# ------------------ Aux regression loss ------------------
|
|
|
- if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
|
|
|
|
|
|
|
+ if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.loss_box_aux:
|
|
|
## delta_preds
|
|
## delta_preds
|
|
|
delta_preds = torch.cat(outputs['pred_delta'], dim=1)
|
|
delta_preds = torch.cat(outputs['pred_delta'], dim=1)
|
|
|
delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
|
|
delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
|