import math

import torch
import torch.nn.functional as F

from utils.box_ops import box_iou


def get_ious_and_iou_loss(inputs,
                          targets,
                          weight=None,
                          box_mode="xyxy",
                          loss_type="iou",
                          reduction="none"):
    """
    Compute iou loss of type ['iou', 'giou', 'linear_iou'].

    Args:
        inputs (tensor): predicted box values.
        targets (tensor): target box values.
        weight (tensor): loss weight.
        box_mode (str): 'xyxy' or 'ltrb'. 'ltrb' boxes are converted
            to 'xyxy' before the IoU is computed.
        loss_type (str): 'iou', 'linear_iou' or 'giou'.
        reduction (str): reduction manner: 'none', 'mean' or 'sum'.

    Returns:
        ious (tensor): ious between inputs and targets.
        loss (tensor): computed iou loss.
    """
    # convert (left, top, right, bottom) distances to xyxy corners
    if box_mode == "ltrb":
        inputs = torch.cat((-inputs[..., :2], inputs[..., 2:]), dim=-1)
        targets = torch.cat((-targets[..., :2], targets[..., 2:]), dim=-1)
    elif box_mode != "xyxy":
        raise NotImplementedError

    eps = torch.finfo(torch.float32).eps

    inputs_area = (inputs[..., 2] - inputs[..., 0]).clamp_(min=0) \
        * (inputs[..., 3] - inputs[..., 1]).clamp_(min=0)
    targets_area = (targets[..., 2] - targets[..., 0]).clamp_(min=0) \
        * (targets[..., 3] - targets[..., 1]).clamp_(min=0)

    w_intersect = (torch.min(inputs[..., 2], targets[..., 2])
                   - torch.max(inputs[..., 0], targets[..., 0])).clamp_(min=0)
    h_intersect = (torch.min(inputs[..., 3], targets[..., 3])
                   - torch.max(inputs[..., 1], targets[..., 1])).clamp_(min=0)

    area_intersect = w_intersect * h_intersect
    area_union = targets_area + inputs_area - area_intersect
    ious = area_intersect / area_union.clamp(min=eps)

    if loss_type == "iou":
        loss = -ious.clamp(min=eps).log()
    elif loss_type == "linear_iou":
        loss = 1 - ious
    elif loss_type == "giou":
        # smallest enclosing box of the two boxes
        g_w_intersect = torch.max(inputs[..., 2], targets[..., 2]) \
            - torch.min(inputs[..., 0], targets[..., 0])
        g_h_intersect = torch.max(inputs[..., 3], targets[..., 3]) \
            - torch.min(inputs[..., 1], targets[..., 1])
        ac_union = g_w_intersect * g_h_intersect
        gious = ious - (ac_union - area_union) / ac_union.clamp(min=eps)
        loss = 1 - gious
    else:
        raise NotImplementedError

    if weight is not None:
        loss = loss * weight.view(loss.size())
        if reduction == "mean":
            loss = loss.sum() / max(weight.sum().item(), eps)
    else:
        if reduction == "mean":
            loss = loss.mean()
    if reduction == "sum":
        loss = loss.sum()

    return ious, loss
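
# ---------------------------------------------------------------------------
# Minimal usage sketch for `get_ious_and_iou_loss` (illustrative only, not
# part of the original module): two hand-made xyxy box pairs, GIoU loss, and
# a per-box weight. All numbers below are made up for demonstration.
# ---------------------------------------------------------------------------
def _demo_iou_loss():
    pred_boxes = torch.tensor([[0., 0., 10., 10.],
                               [5., 5., 15., 15.]])
    gt_boxes = torch.tensor([[0., 0., 10., 10.],
                             [0., 0., 10., 10.]])
    weight = torch.tensor([1.0, 0.5])
    ious, loss = get_ious_and_iou_loss(pred_boxes, gt_boxes,
                                       weight=weight,
                                       box_mode="xyxy",
                                       loss_type="giou",
                                       reduction="mean")
    # ious -> [1.0, ~0.143]: the first pair overlaps perfectly, so its GIoU
    # loss term is 0 and only the second pair contributes to the mean.
    return ious, loss
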
class FcosMatcher(object):
    """
    This code is referenced from
    https://github.com/Megvii-BaseDetection/cvpods
    """
    def __init__(self,
                 num_classes,
                 center_sampling_radius,
                 object_sizes_of_interest,
                 box_weights=[1, 1, 1, 1]):
        self.num_classes = num_classes
        self.center_sampling_radius = center_sampling_radius
        self.object_sizes_of_interest = object_sizes_of_interest
        self.box_weights = box_weights

    def get_deltas(self, anchors, boxes):
        """
        Get box regression transformation deltas (dl, dt, dr, db) that can be
        used to transform the `anchors` into the `boxes`. That is, the relation
        ``boxes == self.apply_deltas(deltas, anchors)`` is true.

        Args:
            anchors (Tensor): anchors, e.g., feature map coordinates.
            boxes (Tensor): target of the transformation, e.g., ground-truth boxes.
        """
        assert isinstance(anchors, torch.Tensor), type(anchors)
        assert isinstance(boxes, torch.Tensor), type(boxes)
        deltas = torch.cat((anchors - boxes[..., :2], boxes[..., 2:] - anchors),
                           dim=-1) * anchors.new_tensor(self.box_weights)
        return deltas

    @torch.no_grad()
    def __call__(self, fpn_strides, anchors, targets):
        """
        fpn_strides: (List) e.g. [8, 16, 32, ...], stride of each FPN level.
        anchors: (List of Tensor) List[F, M, 2], F = num_fpn_levels.
        targets: (List of Dict) per-image dicts:
                 dict{'boxes': [...],
                      'labels': [...],
                      'orig_size': ...}
        """
        gt_classes = []
        gt_anchors_deltas = []
        gt_centerness = []
        device = anchors[0].device

        # List[F, M, 2] -> [M, 2]
        anchors_over_all_feature_maps = torch.cat(anchors, dim=0).to(device)

        for targets_per_image in targets:
            # generate object_sizes_of_interest: List[[M, 2]]
            object_sizes_of_interest = [
                anchors_i.new_tensor(scale_range).unsqueeze(0).expand(anchors_i.size(0), -1)
                for anchors_i, scale_range in zip(anchors, self.object_sizes_of_interest)]
            # List[F, M, 2] -> [M, 2], M = M1 + M2 + ... + MF
            object_sizes_of_interest = torch.cat(object_sizes_of_interest, dim=0)
            # [N, 4]
            tgt_box = targets_per_image['boxes'].to(device)
            # [N,]
            tgt_cls = targets_per_image['labels'].to(device)
            # [N, M, 4], M = M1 + M2 + ... + MF
            deltas = self.get_deltas(anchors_over_all_feature_maps, tgt_box.unsqueeze(1))

            has_gt = (len(tgt_cls) > 0)
            if has_gt:
                if self.center_sampling_radius > 0:
                    # bbox centers: [N, 2]
                    centers = (tgt_box[..., :2] + tgt_box[..., 2:]) * 0.5

                    is_in_boxes = []
                    for stride, anchors_i in zip(fpn_strides, anchors):
                        radius = stride * self.center_sampling_radius
                        # [N, 4]
                        center_boxes = torch.cat((
                            torch.max(centers - radius, tgt_box[:, :2]),
                            torch.min(centers + radius, tgt_box[:, 2:]),
                        ), dim=-1)
                        # [N, Mi, 4]
                        center_deltas = self.get_deltas(anchors_i, center_boxes.unsqueeze(1))
                        # [N, Mi]
                        is_in_boxes.append(center_deltas.min(dim=-1).values > 0)
                    # [N, M], M = M1 + M2 + ... + MF
                    is_in_boxes = torch.cat(is_in_boxes, dim=1)
                else:
                    # no center sampling: use all locations within a ground-truth box
                    # [N, M], M = M1 + M2 + ... + MF
                    is_in_boxes = deltas.min(dim=-1).values > 0

                # [N, M], M = M1 + M2 + ... + MF
                max_deltas = deltas.max(dim=-1).values
                # limit the regression range for each location
                is_cared_in_the_level = \
                    (max_deltas >= object_sizes_of_interest[None, :, 0]) & \
                    (max_deltas <= object_sizes_of_interest[None, :, 1])

                # [N,]
                tgt_box_area = (tgt_box[:, 2] - tgt_box[:, 0]) * (tgt_box[:, 3] - tgt_box[:, 1])
                # [N,] -> [N, 1] -> [N, M]
                gt_positions_area = tgt_box_area.unsqueeze(1).repeat(
                    1, anchors_over_all_feature_maps.size(0))
                gt_positions_area[~is_in_boxes] = math.inf
                gt_positions_area[~is_cared_in_the_level] = math.inf

                # if there is still more than one object for a position,
                # we choose the one with minimal area
                # [M,], each element is the index of the matched ground-truth
                positions_min_area, gt_matched_idxs = gt_positions_area.min(dim=0)

                # ground truth box regression
                # [M, 4]
                gt_anchors_reg_deltas_i = self.get_deltas(
                    anchors_over_all_feature_maps, tgt_box[gt_matched_idxs])

                # [M,]
                tgt_cls_i = tgt_cls[gt_matched_idxs]
                # anchors with area inf are treated as background
                tgt_cls_i[positions_min_area == math.inf] = self.num_classes

                # ground truth centerness
                left_right = gt_anchors_reg_deltas_i[:, [0, 2]]
                top_bottom = gt_anchors_reg_deltas_i[:, [1, 3]]
                # [M,]
                gt_centerness_i = torch.sqrt(
                    (left_right.min(dim=-1).values / left_right.max(dim=-1).values).clamp_(min=0)
                    * (top_bottom.min(dim=-1).values / top_bottom.max(dim=-1).values).clamp_(min=0)
                )

                gt_classes.append(tgt_cls_i)
                gt_anchors_deltas.append(gt_anchors_reg_deltas_i)
                gt_centerness.append(gt_centerness_i)

                # only delete names that exist on every path; `centers` and
                # `center_deltas` are undefined when center sampling is off
                del deltas, max_deltas
            else:
                tgt_cls_i = torch.zeros(anchors_over_all_feature_maps.shape[0], device=device) + self.num_classes
                gt_anchors_reg_deltas_i = torch.zeros([anchors_over_all_feature_maps.shape[0], 4], device=device)
                gt_centerness_i = torch.zeros(anchors_over_all_feature_maps.shape[0], device=device)

                gt_classes.append(tgt_cls_i.long())
                gt_anchors_deltas.append(gt_anchors_reg_deltas_i.float())
                gt_centerness.append(gt_centerness_i.float())

        # [B, M], [B, M, 4], [B, M]
        return torch.stack(gt_classes), torch.stack(gt_anchors_deltas), torch.stack(gt_centerness)
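
# ---------------------------------------------------------------------------
# Minimal usage sketch for `FcosMatcher` (illustrative only, not part of the
# original module). The strides, scale ranges, and the tiny 4x4 / 2x2
# "feature maps" are made-up toy values; in a real model the anchor points
# come from the detection head's grid coordinates.
# ---------------------------------------------------------------------------
def _demo_fcos_matcher():
    fpn_strides = [8, 16]
    matcher = FcosMatcher(num_classes=80,
                          center_sampling_radius=1.5,
                          object_sizes_of_interest=[[0, 64], [64, float('inf')]])
    # anchor points at pixel centers of two toy feature maps
    anchors = []
    for stride, size in zip(fpn_strides, [4, 2]):
        ys, xs = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
        points = torch.stack([xs, ys], dim=-1).reshape(-1, 2).float()
        anchors.append((points + 0.5) * stride)  # [Mi, 2]
    # one image with a single ground-truth box of class 3
    targets = [{'boxes': torch.tensor([[4., 4., 28., 28.]]),
                'labels': torch.tensor([3])}]
    gt_classes, gt_deltas, gt_ctn = matcher(fpn_strides, anchors, targets)
    # shapes: gt_classes [1, 20], gt_deltas [1, 20, 4], gt_ctn [1, 20],
    # where M = 16 + 4 = 20 anchor points over both levels
    return gt_classes, gt_deltas, gt_ctn
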
class AlignedOTAMatcher(object):
    """
    This code is referenced from
    https://github.com/open-mmlab/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py
    """
    def __init__(self,
                 num_classes,
                 soft_center_radius=3.0,
                 topk_candidates=13):
        self.num_classes = num_classes
        self.soft_center_radius = soft_center_radius
        self.topk_candidates = topk_candidates

    @torch.no_grad()
    def __call__(self,
                 fpn_strides,
                 anchors,
                 pred_cls,
                 pred_box,
                 gt_labels,
                 gt_bboxes):
        # [M,]
        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
                             for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
        # List[F, M, 2] -> [M, 2]
        num_gt = len(gt_labels)
        anchors = torch.cat(anchors, dim=0)

        # check gt
        if num_gt == 0 or gt_bboxes.max().item() == 0.:
            return {
                'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape,
                                                      self.num_classes,
                                                      dtype=torch.long),
                'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
                'assign_metrics': gt_bboxes.new_full(pred_cls[..., 0].shape, 0)
            }

        # get inside points: [N, M]
        is_in_gt = self.find_inside_points(gt_bboxes, anchors)
        valid_mask = is_in_gt.sum(dim=0) > 0  # [M,]

        # ----------------------------------- soft center prior -----------------------------------
        gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
        distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
                    ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0)  # [N, M]
        distance = distance * valid_mask.unsqueeze(0)
        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)

        # ----------------------------------- regression cost -----------------------------------
        pair_wise_ious, _ = box_iou(gt_bboxes, pred_box)  # [N, M]
        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * 3.0

        # ----------------------------------- classification cost -----------------------------------
        # select the predicted scores corresponding to the gt_labels
        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()  # [N, M]
        # scale factor
        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
        # cls cost
        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
            pairwise_pred_scores, pair_wise_ious,
            reduction="none") * scale_factor  # [N, M]
        del pairwise_pred_scores

        # foreground cost matrix
        cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
        max_pad_value = torch.ones_like(cost_matrix) * 1e9
        cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1),  # [N, M]
                                  cost_matrix, max_pad_value)

        # ----------------------------------- dynamic label assignment -----------------------------------
        matched_pred_ious, matched_gt_inds, fg_mask_inboxes = self.dynamic_k_matching(
            cost_matrix, pair_wise_ious, num_gt)
        del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss

        # ----------------------------------- process assigned labels -----------------------------------
        assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
                                             self.num_classes)  # [M,]
        assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
        assigned_labels = assigned_labels.long()  # [M,]

        assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0)  # [M, 4]
        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds]  # [M, 4]

        assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0)  # [M,]
        assign_metrics[fg_mask_inboxes] = matched_pred_ious  # [M,]

        assigned_dict = dict(
            assigned_labels=assigned_labels,
            assigned_bboxes=assigned_bboxes,
            assign_metrics=assign_metrics
        )

        return assigned_dict
    def find_inside_points(self, gt_bboxes, anchors):
        """
        gt_bboxes: Tensor -> [N, 4]
        anchors:   Tensor -> [M, 2]
        """
        num_anchors = anchors.shape[0]
        num_gt = gt_bboxes.shape[0]

        anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1)           # [N, M, 2]
        gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1)  # [N, M, 4]

        # offset
        lt = anchors_expand - gt_bboxes_expand[..., :2]
        rb = gt_bboxes_expand[..., 2:] - anchors_expand
        bbox_deltas = torch.cat([lt, rb], dim=-1)

        is_in_gts = bbox_deltas.min(dim=-1).values > 0

        return is_in_gts

    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
        """
        Use IoU and matching cost to calculate the dynamic top-k positive targets.

        Args:
            cost_matrix (Tensor): Cost matrix.
            pairwise_ious (Tensor): Pairwise iou matrix.
            num_gt (int): Number of gt.

        Returns:
            tuple: matched ious, gt indexes, and foreground mask.
        """
        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
        # select candidate topk ious for dynamic-k calculation
        candidate_topk = min(self.topk_candidates, pairwise_ious.size(1))
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
        # calculate dynamic k for each gt
        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)

        # sorting the batch cost matrix is faster than topk
        _, sorted_indices = torch.sort(cost_matrix, dim=1)
        for gt_idx in range(num_gt):
            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
            matching_matrix[gt_idx, :][topk_ids] = 1
        del topk_ious, dynamic_ks, topk_ids

        # if an anchor is matched to more than one gt, keep only the one
        # with the minimal cost
        prior_match_gt_mask = matching_matrix.sum(0) > 1
        if prior_match_gt_mask.sum() > 0:
            cost_min, cost_argmin = torch.min(
                cost_matrix[:, prior_match_gt_mask], dim=0)
            matching_matrix[:, prior_match_gt_mask] *= 0
            matching_matrix[cost_argmin, prior_match_gt_mask] = 1

        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(0) > 0
        matched_pred_ious = (matching_matrix * pairwise_ious).sum(0)[fg_mask_inboxes]
        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)

        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes
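
# ---------------------------------------------------------------------------
# Minimal usage sketch for `AlignedOTAMatcher` (illustrative only, not part
# of the original module). The predictions are random toy tensors; in
# training, `pred_cls` holds per-anchor class logits and `pred_box` holds
# decoded xyxy boxes for one image.
# ---------------------------------------------------------------------------
def _demo_aligned_ota_matcher():
    fpn_strides = [8, 16]
    num_classes = 80
    matcher = AlignedOTAMatcher(num_classes=num_classes)
    # toy anchor points: 16 at stride 8, 4 at stride 16 -> M = 20
    anchors = [torch.rand(16, 2) * 32, torch.rand(4, 2) * 32]
    pred_cls = torch.randn(20, num_classes)  # [M, C] logits
    pred_box = torch.rand(20, 4) * 32
    pred_box[:, 2:] += pred_box[:, :2]       # ensure x2 > x1 and y2 > y1
    gt_labels = torch.tensor([3, 17])
    gt_bboxes = torch.tensor([[2., 2., 20., 20.],
                              [10., 10., 30., 30.]])
    assigned = matcher(fpn_strides, anchors, pred_cls, pred_box,
                       gt_labels, gt_bboxes)
    # assigned['assigned_labels']: [M,], background anchors get `num_classes`
    # assigned['assigned_bboxes']: [M, 4]
    # assigned['assign_metrics']:  [M,], matched IoU of each positive anchor
    return assigned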