loss.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import torch
  2. import torch.nn.functional as F
  3. from .matcher import Yolov2Matcher
  4. from utils.box_ops import get_ious
  5. from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
  6. class Criterion(object):
  7. def __init__(self, cfg, device, num_classes=80):
  8. self.cfg = cfg
  9. self.device = device
  10. self.num_classes = num_classes
  11. # loss weight
  12. self.loss_obj_weight = cfg['loss_obj_weight']
  13. self.loss_cls_weight = cfg['loss_cls_weight']
  14. self.loss_box_weight = cfg['loss_box_weight']
  15. # matcher
  16. self.matcher = Yolov2Matcher(cfg['iou_thresh'], num_classes, cfg['anchor_size'])
  17. def loss_objectness(self, pred_obj, gt_obj):
  18. loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
  19. return loss_obj
  20. def loss_classes(self, pred_cls, gt_label):
  21. loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
  22. return loss_cls
  23. def loss_bboxes(self, pred_box, gt_box):
  24. # regression loss
  25. ious = get_ious(pred_box,
  26. gt_box,
  27. box_mode="xyxy",
  28. iou_type='giou')
  29. loss_box = 1.0 - ious
  30. return loss_box, ious
  31. def __call__(self, outputs, targets):
  32. device = outputs['pred_cls'].device
  33. stride = outputs['stride']
  34. fmp_size = outputs['fmp_size']
  35. (
  36. gt_objectness,
  37. gt_classes,
  38. gt_bboxes,
  39. ) = self.matcher(fmp_size=fmp_size,
  40. stride=stride,
  41. targets=targets)
  42. # List[B, M, C] -> [B, M, C] -> [BM, C]
  43. pred_obj = outputs['pred_obj'].view(-1) # [BM,]
  44. pred_cls = outputs['pred_cls'].view(-1, self.num_classes) # [BM, C]
  45. pred_box = outputs['pred_box'].view(-1, 4) # [BM, 4]
  46. gt_objectness = gt_objectness.view(-1).to(device).float() # [BM,]
  47. gt_classes = gt_classes.view(-1, self.num_classes).to(device).float() # [BM, C]
  48. gt_bboxes = gt_bboxes.view(-1, 4).to(device).float() # [BM, 4]
  49. pos_masks = (gt_objectness > 0)
  50. num_fgs = pos_masks.sum()
  51. if is_dist_avail_and_initialized():
  52. torch.distributed.all_reduce(num_fgs)
  53. num_fgs = (num_fgs / get_world_size()).clamp(1.0)
  54. # box loss
  55. pred_box_pos = pred_box[pos_masks]
  56. gt_bboxes_pos = gt_bboxes[pos_masks]
  57. loss_box, ious = self.loss_bboxes(pred_box_pos, gt_bboxes_pos)
  58. loss_box = loss_box.sum() / num_fgs
  59. # cls loss
  60. pred_cls_pos = pred_cls[pos_masks]
  61. gt_classes_pos = gt_classes[pos_masks] * ious.unsqueeze(-1).clamp(0.)
  62. loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
  63. loss_cls = loss_cls.sum() / num_fgs
  64. # obj loss
  65. loss_obj = self.loss_objectness(pred_obj, gt_objectness)
  66. loss_obj = loss_obj.sum() / num_fgs
  67. # total loss
  68. losses = self.loss_obj_weight * loss_obj + \
  69. self.loss_cls_weight * loss_cls + \
  70. self.loss_box_weight * loss_box
  71. loss_dict = dict(
  72. loss_obj = loss_obj,
  73. loss_cls = loss_cls,
  74. loss_box = loss_box,
  75. losses = losses
  76. )
  77. return loss_dict
  78. def build_criterion(cfg, device, num_classes):
  79. criterion = Criterion(
  80. cfg=cfg,
  81. device=device,
  82. num_classes=num_classes
  83. )
  84. return criterion
# Script entry point: intentionally a no-op, nothing runs when this module is
# executed directly.
if __name__ == "__main__":
    pass