| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- import torch
- import torch.nn.functional as F
- from utils.box_ops import get_ious
- from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
- from .matcher import Yolov5Matcher
class SetCriterion(object):
    """Training loss for the YOLOv5-style detector.

    Combines three terms, each normalized by the number of positive samples
    (averaged across processes when running distributed):

    * objectness: BCE-with-logits over every prediction;
    * classification: BCE-with-logits over positives, with the class target
      scaled by the (non-negative) GIoU of the matched box;
    * box regression: ``1 - GIoU`` over positives.

    Targets are assigned by ``Yolov5Matcher``.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: config object. Reads ``num_classes``, ``loss_obj``,
                ``loss_cls``, ``loss_box`` (term weights), ``anchor_size``
                (indexed 0..2) and ``anchor_thresh``.
        """
        self.cfg = cfg
        self.num_classes = cfg.num_classes
        # Weights of each term in the total loss.
        self.loss_obj_weight = cfg.loss_obj
        self.loss_cls_weight = cfg.loss_cls
        self.loss_box_weight = cfg.loss_box
        # Matcher: the three per-level anchor entries are joined with `+`.
        # NOTE(review): assumes each cfg.anchor_size[i] is a list/tuple (so `+`
        # concatenates) and that the literal 3 is the number of pyramid
        # levels — confirm against Yolov5Matcher's signature.
        anchor_size = cfg.anchor_size[0] + cfg.anchor_size[1] + cfg.anchor_size[2]
        self.matcher = Yolov5Matcher(cfg.num_classes, 3, anchor_size, cfg.anchor_thresh)

    def loss_objectness(self, pred_obj, gt_obj):
        """Return the element-wise BCE-with-logits objectness loss (unreduced)."""
        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
        return loss_obj

    def loss_classes(self, pred_cls, gt_label):
        """Return the element-wise BCE-with-logits classification loss (unreduced)."""
        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
        return loss_cls

    def loss_bboxes(self, pred_box, gt_box):
        """GIoU regression loss.

        Args:
            pred_box: predicted boxes in xyxy format, shape [N, 4].
            gt_box: matched target boxes in xyxy format, shape [N, 4].

        Returns:
            Tuple ``(loss_box, ious)``: ``1 - GIoU`` and the GIoU values
            returned by ``get_ious``.
        """
        ious = get_ious(pred_box,
                        gt_box,
                        box_mode="xyxy",
                        iou_type='giou')
        loss_box = 1.0 - ious
        return loss_box, ious

    def __call__(self, outputs, targets):
        """Compute the weighted detection loss.

        Args:
            outputs: dict holding per-level lists ``'pred_obj'``, ``'pred_cls'``
                and ``'pred_box'`` (concatenated along dim 1), plus ``'strides'``
                and ``'fmp_sizes'``, which are forwarded to the matcher.
            targets: ground-truth annotations, forwarded unchanged to the matcher.

        Returns:
            dict with scalar tensors ``loss_obj``, ``loss_cls``, ``loss_box``
            and their weighted sum ``losses``.
        """
        device = outputs['pred_cls'][0].device
        fpn_strides = outputs['strides']
        fmp_sizes = outputs['fmp_sizes']
        gt_objectness, gt_classes, gt_bboxes = self.matcher(
            fmp_sizes=fmp_sizes,
            fpn_strides=fpn_strides,
            targets=targets,
        )

        # List[B, M, C] -> [B, M, C] -> [BM, C]
        pred_obj = torch.cat(outputs['pred_obj'], dim=1).view(-1)                    # [BM,]
        pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)  # [BM, C]
        pred_box = torch.cat(outputs['pred_box'], dim=1).view(-1, 4)                 # [BM, 4]

        gt_objectness = gt_objectness.view(-1).to(device).float()                    # [BM,]
        gt_classes = gt_classes.view(-1, self.num_classes).to(device).float()        # [BM, C]
        gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()                         # [BM, 4]

        pos_masks = (gt_objectness > 0)
        # Normalizer: positive count, summed over processes by all_reduce and
        # divided by the world size (a per-process mean), then clamped to >= 1
        # so an image batch without positives cannot divide by zero.
        num_fgs = pos_masks.sum()
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_fgs)
        num_fgs = (num_fgs / get_world_size()).clamp(1.0)

        # ---- box loss (positives only) ----
        pred_box_pos = pred_box[pos_masks]
        gt_bboxes_pos = gt_bboxes[pos_masks]
        loss_box, ious = self.loss_bboxes(pred_box_pos, gt_bboxes_pos)
        loss_box = loss_box.sum() / num_fgs

        # ---- cls loss (positives only, IoU-aware soft targets) ----
        # The class target is scaled by the GIoU (negative values clamped to 0).
        # The IoU is detached so it acts purely as a target value; without the
        # detach, the classification loss would back-propagate through its own
        # target into the box branch.
        pred_cls_pos = pred_cls[pos_masks]
        gt_classes_pos = gt_classes[pos_masks] * ious.detach().unsqueeze(-1).clamp(0.)
        loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
        loss_cls = loss_cls.sum() / num_fgs

        # ---- obj loss (all predictions) ----
        loss_obj = self.loss_objectness(pred_obj, gt_objectness)
        loss_obj = loss_obj.sum() / num_fgs

        # ---- total loss ----
        losses = self.loss_obj_weight * loss_obj + \
                 self.loss_cls_weight * loss_cls + \
                 self.loss_box_weight * loss_box

        loss_dict = dict(
            loss_obj=loss_obj,
            loss_cls=loss_cls,
            loss_box=loss_box,
            losses=losses,
        )
        return loss_dict
-
-
if __name__ == "__main__":
    # No standalone behavior: executing this file directly does nothing.
    ...
|