loss.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import torch
  2. import torch.nn.functional as F
  3. from .matcher import YoloMatcher
  4. class Criterion(object):
  5. def __init__(self, cfg, device, num_classes=80):
  6. self.cfg = cfg
  7. self.device = device
  8. self.num_classes = num_classes
  9. self.loss_obj_weight = cfg['loss_obj_weight']
  10. self.loss_cls_weight = cfg['loss_cls_weight']
  11. self.loss_txty_weight = cfg['loss_txty_weight']
  12. self.loss_twth_weight = cfg['loss_twth_weight']
  13. # matcher
  14. self.matcher = YoloMatcher(num_classes=num_classes)
  15. def loss_objectness(self, pred_obj, gt_obj):
  16. obj_score = torch.clamp(torch.sigmoid(pred_obj), min=1e-4, max=1.0 - 1e-4)
  17. # obj loss
  18. pos_id = (gt_obj==1.0).float()
  19. pos_loss = pos_id * (obj_score - gt_obj)**2
  20. # noobj loss
  21. neg_id = (gt_obj==0.0).float()
  22. neg_loss = neg_id * (obj_score)**2
  23. # total loss
  24. loss_obj = 5.0 * pos_loss + 1.0 * neg_loss
  25. return loss_obj
  26. def loss_labels(self, pred_cls, gt_label):
  27. loss_cls = F.cross_entropy(pred_cls, gt_label, reduction='none')
  28. return loss_cls
  29. def loss_txty(self, pred_txty, gt_txty, gt_box_weight):
  30. # txty loss
  31. loss_txty = F.binary_cross_entropy_with_logits(
  32. pred_txty, gt_txty, reduction='none').sum(-1)
  33. loss_txty *= gt_box_weight
  34. return loss_txty
  35. def loss_twth(self, pred_twth, gt_twth, gt_box_weight):
  36. # twth loss
  37. loss_twth = F.mse_loss(pred_twth, gt_twth, reduction='none').sum(-1)
  38. loss_twth *= gt_box_weight
  39. return loss_twth
  40. def __call__(self, outputs, targets):
  41. device = outputs['pred_cls'][0].device
  42. stride = outputs['stride']
  43. img_size = outputs['img_size']
  44. (
  45. gt_objectness,
  46. gt_labels,
  47. gt_bboxes,
  48. gt_box_weight
  49. ) = self.matcher(img_size=img_size,
  50. stride=stride,
  51. targets=targets)
  52. # List[B, M, C] -> [B, M, C] -> [BM, C]
  53. batch_size = outputs['pred_obj'].shape[0]
  54. pred_obj = outputs['pred_obj'].view(-1)
  55. pred_cls = outputs['pred_cls'].view(-1, self.num_classes)
  56. pred_txty = outputs['pred_txty'].view(-1, 2)
  57. pred_twth = outputs['pred_twth'].view(-1, 2)
  58. gt_objectness = gt_objectness.view(-1).to(device).float()
  59. gt_labels = gt_labels.view(-1).to(device).long()
  60. gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()
  61. gt_box_weight = gt_box_weight.view(-1).to(device).float()
  62. pos_masks = (gt_objectness > 0)
  63. # objectness loss
  64. loss_obj = self.loss_objectness(pred_obj, gt_objectness)
  65. loss_obj = loss_obj.sum() / batch_size
  66. # classification loss
  67. pred_cls_pos = pred_cls[pos_masks]
  68. gt_labels_pos = gt_labels[pos_masks]
  69. loss_cls = self.loss_labels(pred_cls_pos, gt_labels_pos)
  70. loss_cls = loss_cls.sum() / batch_size
  71. # txty loss
  72. pred_txty_pos = pred_txty[pos_masks]
  73. gt_txty_pos = gt_bboxes[pos_masks][..., :2]
  74. gt_box_weight_pos = gt_box_weight[pos_masks]
  75. loss_txty = self.loss_txty(pred_txty_pos, gt_txty_pos, gt_box_weight_pos)
  76. loss_txty = loss_txty.sum() / batch_size
  77. # twth loss
  78. pred_twth_pos = pred_twth[pos_masks]
  79. gt_twth_pos = gt_bboxes[pos_masks][..., 2:]
  80. loss_twth = self.loss_twth(pred_twth_pos, gt_twth_pos, gt_box_weight_pos)
  81. loss_twth = loss_twth.sum() / batch_size
  82. # total loss
  83. losses = self.loss_obj_weight * loss_obj + \
  84. self.loss_cls_weight * loss_cls + \
  85. self.loss_txty_weight * loss_txty + \
  86. self.loss_twth_weight * loss_twth
  87. loss_dict = dict(
  88. loss_obj = loss_obj,
  89. loss_cls = loss_cls,
  90. loss_txty = loss_txty,
  91. loss_twth = loss_twth,
  92. losses = losses
  93. )
  94. return loss_dict
  95. def build_criterion(cfg, device, num_classes):
  96. criterion = Criterion(
  97. cfg=cfg,
  98. device=device,
  99. num_classes=num_classes
  100. )
  101. return criterion
  102. if __name__ == "__main__":
  103. pass