loss.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import torch
  2. import torch.nn.functional as F
  3. from .matcher import Yolov3Matcher
  4. from utils.box_ops import get_ious
  5. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  6. class SetCriterion(object):
  7. def __init__(self, cfg):
  8. self.cfg = cfg
  9. self.num_classes = cfg.num_classes
  10. # loss weight
  11. self.loss_obj_weight = cfg.loss_obj
  12. self.loss_cls_weight = cfg.loss_cls
  13. self.loss_box_weight = cfg.loss_box
  14. # matcher
  15. self.matcher = Yolov3Matcher(self.num_classes, 3, cfg.anchor_size, cfg.iou_thresh)
  16. def loss_objectness(self, pred_obj, gt_obj):
  17. loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
  18. return loss_obj
  19. def loss_classes(self, pred_cls, gt_label):
  20. loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
  21. return loss_cls
  22. def loss_bboxes(self, pred_box, gt_box):
  23. # regression loss
  24. ious = get_ious(pred_box,
  25. gt_box,
  26. box_mode="xyxy",
  27. iou_type='giou')
  28. loss_box = 1.0 - ious
  29. return loss_box, ious
  30. def __call__(self, outputs, targets):
  31. # label assignment
  32. (
  33. gt_objectness,
  34. gt_classes,
  35. gt_bboxes,
  36. ) = self.matcher(fmp_sizes = outputs['fmp_sizes'],
  37. fpn_strides = outputs['strides'],
  38. targets = targets)
  39. # List[B, M, C] -> [B, M, C] -> [BM, C]
  40. pred_obj = torch.cat(outputs['pred_obj'], dim=1).view(-1) # [BM,]
  41. pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes) # [BM, C]
  42. pred_box = torch.cat(outputs['pred_box'], dim=1).view(-1, 4) # [BM, 4]
  43. device = pred_box.device
  44. gt_objectness = gt_objectness.view(-1).to(device).float() # [BM,]
  45. gt_classes = gt_classes.view(-1, self.num_classes).to(device).float() # [BM, C]
  46. gt_bboxes = gt_bboxes.view(-1, 4).to(device).float() # [BM, 4]
  47. pos_masks = (gt_objectness > 0)
  48. num_fgs = pos_masks.sum()
  49. if is_dist_avail_and_initialized():
  50. torch.distributed.all_reduce(num_fgs)
  51. num_fgs = (num_fgs / get_world_size()).clamp(1.0)
  52. # box loss
  53. pred_box_pos = pred_box[pos_masks]
  54. gt_bboxes_pos = gt_bboxes[pos_masks]
  55. loss_box, ious = self.loss_bboxes(pred_box_pos, gt_bboxes_pos)
  56. loss_box = loss_box.sum() / num_fgs
  57. # cls loss
  58. pred_cls_pos = pred_cls[pos_masks]
  59. gt_classes_pos = gt_classes[pos_masks] * ious.unsqueeze(-1).clamp(0.)
  60. loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
  61. loss_cls = loss_cls.sum() / num_fgs
  62. # obj loss
  63. loss_obj = self.loss_objectness(pred_obj, gt_objectness)
  64. loss_obj = loss_obj.sum() / num_fgs
  65. # total loss
  66. losses = self.loss_obj_weight * loss_obj + \
  67. self.loss_cls_weight * loss_cls + \
  68. self.loss_box_weight * loss_box
  69. loss_dict = dict(
  70. loss_obj = loss_obj,
  71. loss_cls = loss_cls,
  72. loss_box = loss_box,
  73. losses = losses
  74. )
  75. return loss_dict
  76. if __name__ == "__main__":
  77. pass