|
|
@@ -7,13 +7,18 @@ from utils.distributed_utils import get_world_size, is_dist_avail_and_initialize
|
|
|
|
|
|
|
|
|
class Criterion(object):
|
|
|
- def __init__(self,
|
|
|
+ def __init__(self,
|
|
|
+ args,
|
|
|
cfg,
|
|
|
device,
|
|
|
num_classes=80):
|
|
|
+ self.args = args
|
|
|
self.cfg = cfg
|
|
|
self.device = device
|
|
|
self.num_classes = num_classes
|
|
|
+ self.max_epoch = args.max_epoch
|
|
|
+ self.no_aug_epoch = args.no_aug_epoch
|
|
|
+ self.aux_bbox_loss = False
|
|
|
# loss weight
|
|
|
self.loss_obj_weight = cfg['loss_obj_weight']
|
|
|
self.loss_cls_weight = cfg['loss_cls_weight']
|
|
|
@@ -47,11 +52,26 @@ class Criterion(object):
|
|
|
return loss_box
|
|
|
|
|
|
|
|
|
+    def loss_bboxes_aux(self, pred_reg, gt_box, anchors, stride_tensors):  # auxiliary L1 loss between pred_reg and the encoded gt boxes
|
|
|
+        # xyxy -> cxcy&bwbh: centre is the midpoint of the two corners, size is their difference
|
|
|
+        gt_cxcy = (gt_box[..., :2] + gt_box[..., 2:]) * 0.5
|
|
|
+        gt_bwbh = gt_box[..., 2:] - gt_box[..., :2]
|
|
|
+        # encode gt box: centre as offset from the anchor in stride units, size as log of the stride-normalised width/height
|
|
|
+        gt_cxcy_encode = (gt_cxcy - anchors) / stride_tensors  # __call__ passes anchors viewed as (-1, 2) and stride_tensors as (-1, 1), both masked by fg_masks
|
|
|
+        gt_bwbh_encode = torch.log(gt_bwbh / stride_tensors)  # NOTE(review): log gives -inf/nan for zero or negative width/height - confirm degenerate gt boxes are filtered upstream
|
|
|
+        gt_box_encode = torch.cat([gt_cxcy_encode, gt_bwbh_encode], dim=-1)  # last dim ordered as (cx, cy, w, h) encodings
|
|
|
+        # l1 loss, kept unreduced (reduction='none'); __call__ sums it and divides by num_fgs
|
|
|
+        loss_box_aux = F.l1_loss(pred_reg, gt_box_encode, reduction='none')
|
|
|
+
|
|
|
+        return loss_box_aux
|
|
|
+
|
|
|
+
|
|
|
def __call__(self, outputs, targets, epoch=0):
|
|
|
"""
|
|
|
outputs['pred_obj']: List(Tensor) [B, M, 1]
|
|
|
outputs['pred_cls']: List(Tensor) [B, M, C]
|
|
|
outputs['pred_box']: List(Tensor) [B, M, 4]
|
|
|
+            outputs['pred_reg']: List(Tensor) [B, M, 4]
|
|
|
outputs['strides']: List(Int) [8, 16, 32] output stride
|
|
|
targets: (List) [dict{'boxes': [...],
|
|
|
'labels': [...],
|
|
|
@@ -120,16 +140,16 @@ class Criterion(object):
|
|
|
torch.distributed.all_reduce(num_fgs)
|
|
|
num_fgs = (num_fgs / get_world_size()).clamp(1.0)
|
|
|
|
|
|
- # ------------------ objecntness loss ------------------
|
|
|
+        # ------------------ Objectness loss ------------------
|
|
|
loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
|
|
|
loss_obj = loss_obj.sum() / num_fgs
|
|
|
|
|
|
- # ------------------ classification loss ------------------
|
|
|
+ # ------------------ Classification loss ------------------
|
|
|
cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
|
|
|
loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
|
|
|
loss_cls = loss_cls.sum() / num_fgs
|
|
|
|
|
|
- # ------------------ regression loss ------------------
|
|
|
+ # ------------------ Regression loss ------------------
|
|
|
box_preds_pos = box_preds.view(-1, 4)[fg_masks]
|
|
|
loss_box = self.loss_bboxes(box_preds_pos, box_targets)
|
|
|
loss_box = loss_box.sum() / num_fgs
|
|
|
@@ -146,11 +166,48 @@ class Criterion(object):
|
|
|
losses = losses
|
|
|
)
|
|
|
|
|
|
+ # ------------------ Aux regression loss ------------------
|
|
|
+ loss_box_aux = None
|
|
|
+ if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
|
|
|
+ ## reg_preds
|
|
|
+ reg_preds = torch.cat(outputs['pred_reg'], dim=1)
|
|
|
+ reg_preds_pos = reg_preds.view(-1, 4)[fg_masks]
|
|
|
+ ## anchor tensors
|
|
|
+ anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
|
|
|
+ anchors_tensors_pos = anchors_tensors.view(-1, 2)[fg_masks]
|
|
|
+ ## stride tensors
|
|
|
+ stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
|
|
|
+ stride_tensors_pos = stride_tensors.view(-1, 1)[fg_masks]
|
|
|
+ ## aux loss
|
|
|
+ loss_box_aux = self.loss_bboxes_aux(reg_preds_pos, box_targets, anchors_tensors_pos, stride_tensors_pos)
|
|
|
+ loss_box_aux = loss_box_aux.sum() / num_fgs
|
|
|
+
|
|
|
+ losses += loss_box_aux
|
|
|
+
|
|
|
+ # Loss dict
|
|
|
+ if loss_box_aux is None:
|
|
|
+ loss_dict = dict(
|
|
|
+ loss_obj = loss_obj,
|
|
|
+ loss_cls = loss_cls,
|
|
|
+ loss_box = loss_box,
|
|
|
+ losses = losses
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ loss_dict = dict(
|
|
|
+ loss_obj = loss_obj,
|
|
|
+ loss_cls = loss_cls,
|
|
|
+ loss_box = loss_box,
|
|
|
+ loss_box_aux = loss_box_aux,
|
|
|
+ losses = losses
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
return loss_dict
|
|
|
|
|
|
|
|
|
-def build_criterion(cfg, device, num_classes):
|
|
|
+def build_criterion(args, cfg, device, num_classes):
|
|
|
criterion = Criterion(
|
|
|
+ args=args,
|
|
|
cfg=cfg,
|
|
|
device=device,
|
|
|
num_classes=num_classes
|