|
@@ -15,7 +15,7 @@ class Criterion(object):
|
|
|
self.num_classes = num_classes
|
|
self.num_classes = num_classes
|
|
|
self.max_epoch = args.max_epoch
|
|
self.max_epoch = args.max_epoch
|
|
|
self.no_aug_epoch = args.no_aug_epoch
|
|
self.no_aug_epoch = args.no_aug_epoch
|
|
|
- self.aux_bbox_loss = False
|
|
|
|
|
|
|
+ self.aux_bbox_loss = cfg['aux_bbox_loss']
|
|
|
# --------------- Loss config ---------------
|
|
# --------------- Loss config ---------------
|
|
|
self.loss_cls_weight = cfg['loss_cls_weight']
|
|
self.loss_cls_weight = cfg['loss_cls_weight']
|
|
|
self.loss_box_weight = cfg['loss_box_weight']
|
|
self.loss_box_weight = cfg['loss_box_weight']
|
|
@@ -168,7 +168,7 @@ class Criterion(object):
|
|
|
|
|
|
|
|
# ------------------ Aux regression loss ------------------
|
|
# ------------------ Aux regression loss ------------------
|
|
|
loss_box_aux = None
|
|
loss_box_aux = None
|
|
|
- if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
|
|
|
|
|
|
|
+ if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.aux_bbox_loss:
|
|
|
## reg_preds
|
|
## reg_preds
|
|
|
reg_preds = outputs['pred_reg']
|
|
reg_preds = outputs['pred_reg']
|
|
|
reg_preds_pos = reg_preds.view(-1, 4)[pos_inds]
|
|
reg_preds_pos = reg_preds.view(-1, 4)[pos_inds]
|
|
@@ -203,10 +203,10 @@ class Criterion(object):
|
|
|
|
|
|
|
|
def __call__(self, outputs, targets, epoch=0):
|
|
def __call__(self, outputs, targets, epoch=0):
|
|
|
# -------------- Main loss --------------
|
|
# -------------- Main loss --------------
|
|
|
- main_loss_dict = self.compute_loss(outputs, targets, epoch)
|
|
|
|
|
|
|
+ main_loss_dict = self.compute_loss(outputs, targets, False, epoch)
|
|
|
|
|
|
|
|
# -------------- Aux loss --------------
|
|
# -------------- Aux loss --------------
|
|
|
- aux_loss_dict = self.compute_loss(outputs['aux_outputs'], targets, epoch)
|
|
|
|
|
|
|
+ aux_loss_dict = self.compute_loss(outputs['aux_outputs'], targets, True, epoch)
|
|
|
|
|
|
|
|
# Reformat loss dict
|
|
# Reformat loss dict
|
|
|
loss_dict = dict()
|
|
loss_dict = dict()
|