|
|
@@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|
|
from utils.box_ops import bbox2dist, get_ious
|
|
|
from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
|
|
|
|
|
|
-from .matcher import build_matcher
|
|
|
+from .matcher import SimOTA
|
|
|
|
|
|
|
|
|
# ----------------------- Criterion for training -----------------------
|
|
|
@@ -16,25 +16,17 @@ class Criterion(object):
|
|
|
self.num_classes = num_classes
|
|
|
self.max_epoch = args.max_epoch
|
|
|
self.no_aug_epoch = args.no_aug_epoch
|
|
|
- self.use_ema_update = cfg['ema_update']
|
|
|
- self.loss_box_aux = cfg['loss_box_aux']
|
|
|
# ---------------- Loss weight ----------------
|
|
|
- loss_weights = cfg['loss_weights'][cfg['matcher']]
|
|
|
- self.loss_cls_weight = loss_weights['loss_cls_weight']
|
|
|
- self.loss_box_weight = loss_weights['loss_box_weight']
|
|
|
- self.loss_dfl_weight = loss_weights['loss_dfl_weight']
|
|
|
+ self.loss_box_aux = cfg['loss_box_aux']
|
|
|
+ self.loss_cls_weight = cfg['loss_cls_weight']
|
|
|
+ self.loss_box_weight = cfg['loss_box_weight']
|
|
|
+ self.loss_dfl_weight = cfg['loss_dfl_weight']
|
|
|
# ---------------- Matcher ----------------
|
|
|
## Aligned SimOTA assigner
|
|
|
- self.matcher = build_matcher(cfg, num_classes)
|
|
|
-
|
|
|
- def ema_update(self, name: str, value, initial_value, momentum=0.9):
|
|
|
- if hasattr(self, name):
|
|
|
- old = getattr(self, name)
|
|
|
- else:
|
|
|
- old = initial_value
|
|
|
- new = old * momentum + value * (1 - momentum)
|
|
|
- setattr(self, name, new)
|
|
|
- return new
|
|
|
+ self.matcher_hpy = cfg['matcher_hpy']
|
|
|
+ self.matcher = SimOTA(num_classes = num_classes,
|
|
|
+ center_sampling_radius = self.matcher_hpy['center_sampling_radius'],
|
|
|
+ topk_candidate = self.matcher_hpy['topk_candidate'])
|
|
|
|
|
|
# ----------------- Loss functions -----------------
|
|
|
def loss_classes(self, pred_cls, gt_score):
|
|
|
@@ -117,7 +109,7 @@ class Criterion(object):
|
|
|
return loss_box_aux
|
|
|
|
|
|
# ----------------- Main process -----------------
|
|
|
- def loss_simota(self, outputs, targets, epoch=0):
|
|
|
+ def compute_loss1(self, outputs, targets, epoch=0):
|
|
|
bs = outputs['pred_cls'][0].shape[0]
|
|
|
device = outputs['pred_cls'][0].device
|
|
|
fpn_strides = outputs['strides']
|
|
|
@@ -177,22 +169,16 @@ class Criterion(object):
|
|
|
if is_dist_avail_and_initialized():
|
|
|
torch.distributed.all_reduce(num_fgs)
|
|
|
num_fgs = (num_fgs / get_world_size()).clamp(1.0)
|
|
|
-
|
|
|
- # update loss normalizer with EMA
|
|
|
- if self.use_ema_update:
|
|
|
- normalizer = self.ema_update("loss_normalizer", max(num_fgs, 1), 100)
|
|
|
- else:
|
|
|
- normalizer = num_fgs
|
|
|
|
|
|
# ------------------ Classification loss ------------------
|
|
|
cls_preds = cls_preds.view(-1, self.num_classes)
|
|
|
loss_cls = self.loss_classes(cls_preds, cls_targets)
|
|
|
- loss_cls = loss_cls.sum() / normalizer
|
|
|
+ loss_cls = loss_cls.sum() / num_fgs
|
|
|
|
|
|
# ------------------ Regression loss ------------------
|
|
|
box_preds_pos = box_preds.view(-1, 4)[fg_masks]
|
|
|
loss_box = self.loss_bboxes(box_preds_pos, box_targets)
|
|
|
- loss_box = loss_box.sum() / normalizer
|
|
|
+ loss_box = loss_box.sum() / num_fgs
|
|
|
|
|
|
# ------------------ Distribution focal loss ------------------
|
|
|
## process anchors
|
|
|
@@ -207,7 +193,7 @@ class Criterion(object):
|
|
|
strides_pos = strides[fg_masks]
|
|
|
## compute dfl
|
|
|
loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos)
|
|
|
- loss_dfl = loss_dfl.sum() / normalizer
|
|
|
+ loss_dfl = loss_dfl.sum() / num_fgs
|
|
|
|
|
|
# total loss
|
|
|
losses = self.loss_cls_weight * loss_cls + \
|
|
|
@@ -228,7 +214,7 @@ class Criterion(object):
|
|
|
delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
|
|
|
## aux loss
|
|
|
loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets, anchors_pos, strides_pos)
|
|
|
- loss_box_aux = loss_box_aux.sum() / normalizer
|
|
|
+ loss_box_aux = loss_box_aux.sum() / num_fgs
|
|
|
|
|
|
losses += loss_box_aux
|
|
|
loss_dict['loss_box_aux'] = loss_box_aux
|
|
|
@@ -236,19 +222,12 @@ class Criterion(object):
|
|
|
|
|
|
return loss_dict
|
|
|
|
|
|
- def loss_aligned_simota(self, outputs, targets, epoch=0):
|
|
|
- """
|
|
|
- outputs['pred_cls']: List(Tensor) [B, M, C]
|
|
|
- outputs['pred_box']: List(Tensor) [B, M, 4]
|
|
|
- outputs['strides']: List(Int) [8, 16, 32] output stride
|
|
|
- targets: (List) [dict{'boxes': [...],
|
|
|
- 'labels': [...],
|
|
|
- 'orig_size': ...}, ...]
|
|
|
- """
|
|
|
+ def compute_loss2(self, outputs, targets, epoch=0):
|
|
|
bs = outputs['pred_cls'][0].shape[0]
|
|
|
device = outputs['pred_cls'][0].device
|
|
|
fpn_strides = outputs['strides']
|
|
|
anchors = outputs['anchors']
|
|
|
+ num_anchors = sum([ab.shape[0] for ab in anchors])
|
|
|
# preds: [B, M, C]
|
|
|
cls_preds = torch.cat(outputs['pred_cls'], dim=1)
|
|
|
reg_preds = torch.cat(outputs['pred_reg'], dim=1)
|
|
|
@@ -257,54 +236,65 @@ class Criterion(object):
|
|
|
# --------------- label assignment ---------------
|
|
|
cls_targets = []
|
|
|
box_targets = []
|
|
|
- assign_metrics = []
|
|
|
+ iou_targets = []
|
|
|
+ fg_masks = []
|
|
|
for batch_idx in range(bs):
|
|
|
- tgt_labels = targets[batch_idx]["labels"].to(device) # [N,]
|
|
|
- tgt_bboxes = targets[batch_idx]["boxes"].to(device) # [N, 4]
|
|
|
- # label assignment
|
|
|
- assigned_result = self.matcher(fpn_strides=fpn_strides,
|
|
|
- anchors=anchors,
|
|
|
- pred_cls=cls_preds[batch_idx].detach(),
|
|
|
- pred_box=box_preds[batch_idx].detach(),
|
|
|
- gt_labels=tgt_labels,
|
|
|
- gt_bboxes=tgt_bboxes
|
|
|
- )
|
|
|
- cls_targets.append(assigned_result['assigned_labels'])
|
|
|
- box_targets.append(assigned_result['assigned_bboxes'])
|
|
|
- assign_metrics.append(assigned_result['assign_metrics'])
|
|
|
-
|
|
|
- cls_targets = torch.cat(cls_targets, dim=0)
|
|
|
- box_targets = torch.cat(box_targets, dim=0)
|
|
|
- assign_metrics = torch.cat(assign_metrics, dim=0)
|
|
|
-
|
|
|
- # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
|
|
|
- bg_class_ind = self.num_classes
|
|
|
- pos_inds = ((cls_targets >= 0)
|
|
|
- & (cls_targets < bg_class_ind)).nonzero().squeeze(1)
|
|
|
- num_fgs = assign_metrics.sum()
|
|
|
+ tgt_labels = targets[batch_idx]["labels"].to(device)
|
|
|
+ tgt_bboxes = targets[batch_idx]["boxes"].to(device)
|
|
|
|
|
|
+ # check target
|
|
|
+ if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
|
|
|
+ # There is no valid gt
|
|
|
+ cls_target = cls_preds.new_full([num_anchors], self.num_classes, dtype=torch.long)
|
|
|
+ iou_target = cls_preds.new_zeros([num_anchors])
|
|
|
+ box_target = cls_preds.new_zeros((0, 4))
|
|
|
+ fg_mask = cls_preds.new_zeros(num_anchors).bool()
|
|
|
+ else:
|
|
|
+ (
|
|
|
+ fg_mask,
|
|
|
+ assigned_labels,
|
|
|
+ assigned_ious,
|
|
|
+ assigned_indexs
|
|
|
+ ) = self.matcher(
|
|
|
+ fpn_strides = fpn_strides,
|
|
|
+ anchors = anchors,
|
|
|
+ pred_cls = cls_preds[batch_idx],
|
|
|
+ pred_box = box_preds[batch_idx],
|
|
|
+ tgt_labels = tgt_labels,
|
|
|
+ tgt_bboxes = tgt_bboxes
|
|
|
+ )
|
|
|
+ # prepare cls targets
|
|
|
+ cls_target = assigned_labels.new_full([num_anchors], self.num_classes, dtype=torch.long)
|
|
|
+ cls_target[fg_mask] = assigned_labels
|
|
|
+ iou_target = assigned_ious.new_zeros([num_anchors])
|
|
|
+ iou_target[fg_mask] = assigned_ious
|
|
|
+ # prepare box targets
|
|
|
+ box_target = tgt_bboxes[assigned_indexs]
|
|
|
+
|
|
|
+ cls_targets.append(cls_target)
|
|
|
+ box_targets.append(box_target)
|
|
|
+ iou_targets.append(iou_target)
|
|
|
+ fg_masks.append(fg_mask)
|
|
|
+
|
|
|
+ cls_targets = torch.cat(cls_targets, 0) # [M,]
|
|
|
+ box_targets = torch.cat(box_targets, 0) # [M, 4]
|
|
|
+ iou_targets = torch.cat(iou_targets, 0) # [M,]
|
|
|
+ fg_masks = torch.cat(fg_masks, 0)
|
|
|
+ num_fgs = fg_masks.sum()
|
|
|
+
|
|
|
+ # average loss normalizer across all the GPUs
|
|
|
if is_dist_avail_and_initialized():
|
|
|
torch.distributed.all_reduce(num_fgs)
|
|
|
- num_fgs = (num_fgs / get_world_size()).clamp(1.0).item()
|
|
|
+ num_fgs = (num_fgs / get_world_size()).clamp(1.0)
|
|
|
|
|
|
- # update loss normalizer with EMA
|
|
|
- if self.use_ema_update:
|
|
|
- normalizer = self.ema_update("loss_normalizer", max(num_fgs, 1), 100)
|
|
|
- else:
|
|
|
- normalizer = num_fgs
|
|
|
-
|
|
|
- # ---------------------------- Classification loss ----------------------------
|
|
|
+ # ------------------ Classification loss ------------------
|
|
|
cls_preds = cls_preds.view(-1, self.num_classes)
|
|
|
- loss_cls = self.loss_classes_qfl(cls_preds, (cls_targets, assign_metrics))
|
|
|
- loss_cls = loss_cls.sum() / normalizer
|
|
|
+ loss_cls = self.loss_classes_qfl(cls_preds, (cls_targets, iou_targets))
|
|
|
+ loss_cls = loss_cls.sum() / num_fgs
|
|
|
|
|
|
- # ---------------------------- Regression loss ----------------------------
|
|
|
- box_preds_pos = box_preds.view(-1, 4)[pos_inds]
|
|
|
- box_targets_pos = box_targets[pos_inds]
|
|
|
- box_weight_pos = assign_metrics[pos_inds]
|
|
|
- loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos)
|
|
|
- loss_box *= box_weight_pos
|
|
|
- loss_box = loss_box.sum() / normalizer
|
|
|
+ # ------------------ Regression loss ------------------
|
|
|
+ loss_box = self.loss_bboxes(box_preds.view(-1, 4)[fg_masks], box_targets)
|
|
|
+ loss_box = loss_box.sum() / num_fgs
|
|
|
|
|
|
# ------------------ Distribution focal loss ------------------
|
|
|
## process anchors
|
|
|
@@ -314,13 +304,12 @@ class Criterion(object):
|
|
|
strides = torch.cat(outputs['stride_tensor'], dim=0)
|
|
|
strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
|
|
|
## fg preds
|
|
|
- reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[pos_inds]
|
|
|
- anchors_pos = anchors[pos_inds]
|
|
|
- strides_pos = strides[pos_inds]
|
|
|
+ reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
|
|
|
+ anchors_pos = anchors[fg_masks]
|
|
|
+ strides_pos = strides[fg_masks]
|
|
|
## compute dfl
|
|
|
- loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos)
|
|
|
- loss_dfl *= box_weight_pos
|
|
|
- loss_dfl = loss_dfl.sum() / normalizer
|
|
|
+ loss_dfl = self.loss_dfl(reg_preds_pos, box_targets, anchors_pos, strides_pos)
|
|
|
+ loss_dfl = loss_dfl.sum() / num_fgs
|
|
|
|
|
|
# total loss
|
|
|
losses = self.loss_cls_weight * loss_cls + \
|
|
|
@@ -338,22 +327,26 @@ class Criterion(object):
|
|
|
if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.loss_box_aux:
|
|
|
## delta_preds
|
|
|
delta_preds = torch.cat(outputs['pred_delta'], dim=1)
|
|
|
- delta_preds_pos = delta_preds.view(-1, 4)[pos_inds]
|
|
|
+ delta_preds_pos = delta_preds.view(-1, 4)[fg_masks]
|
|
|
## aux loss
|
|
|
- loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets_pos, anchors_pos, strides_pos)
|
|
|
- loss_box_aux = loss_box_aux.sum() / normalizer
|
|
|
+ loss_box_aux = self.loss_bboxes_aux(delta_preds_pos, box_targets, anchors_pos, strides_pos)
|
|
|
+ loss_box_aux = loss_box_aux.sum() / num_fgs
|
|
|
|
|
|
losses += loss_box_aux
|
|
|
loss_dict['loss_box_aux'] = loss_box_aux
|
|
|
|
|
|
+
|
|
|
return loss_dict
|
|
|
|
|
|
|
|
|
def __call__(self, outputs, targets, epoch=0):
|
|
|
- if self.cfg['matcher'] == "simota":
|
|
|
- return self.loss_simota(outputs, targets, epoch)
|
|
|
- elif self.cfg['matcher'] == "aligned_simota":
|
|
|
- return self.loss_aligned_simota(outputs, targets, epoch)
|
|
|
+ if self.cfg['cls_loss'] == "bce":
|
|
|
+ return self.compute_loss1(outputs, targets, epoch)
|
|
|
+ elif self.cfg['cls_loss'] == "qfl":
|
|
|
+ self.loss_box_weight = 2.0
|
|
|
+ return self.compute_loss2(outputs, targets, epoch)
|
|
|
+ else:
|
|
|
+ raise NotImplementedError
|
|
|
|
|
|
|
|
|
def build_criterion(args, cfg, device, num_classes):
|