loss_utils.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import torch
  2. import torch.distributed as dist
  3. from torchvision.ops.boxes import box_area
  4. # ------------------------- For box -------------------------
  5. def box_iou(boxes1, boxes2):
  6. area1 = box_area(boxes1)
  7. area2 = box_area(boxes2)
  8. lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
  9. rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
  10. wh = (rb - lt).clamp(min=0) # [N,M,2]
  11. inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
  12. union = area1[:, None] + area2 - inter
  13. iou = inter / union
  14. return iou, union
  15. def get_ious(bboxes1,
  16. bboxes2,
  17. box_mode="xyxy",
  18. iou_type="iou"):
  19. """
  20. Compute iou loss of type ['iou', 'giou', 'linear_iou']
  21. Args:
  22. inputs (tensor): pred values
  23. targets (tensor): target values
  24. weight (tensor): loss weight
  25. box_mode (str): 'xyxy' or 'ltrb', 'ltrb' is currently supported.
  26. loss_type (str): 'giou' or 'iou' or 'linear_iou'
  27. reduction (str): reduction manner
  28. Returns:
  29. loss (tensor): computed iou loss.
  30. """
  31. if box_mode == "ltrb":
  32. bboxes1 = torch.cat((-bboxes1[..., :2], bboxes1[..., 2:]), dim=-1)
  33. bboxes2 = torch.cat((-bboxes2[..., :2], bboxes2[..., 2:]), dim=-1)
  34. elif box_mode != "xyxy":
  35. raise NotImplementedError
  36. eps = torch.finfo(torch.float32).eps
  37. bboxes1_area = (bboxes1[..., 2] - bboxes1[..., 0]).clamp_(min=0) \
  38. * (bboxes1[..., 3] - bboxes1[..., 1]).clamp_(min=0)
  39. bboxes2_area = (bboxes2[..., 2] - bboxes2[..., 0]).clamp_(min=0) \
  40. * (bboxes2[..., 3] - bboxes2[..., 1]).clamp_(min=0)
  41. w_intersect = (torch.min(bboxes1[..., 2], bboxes2[..., 2])
  42. - torch.max(bboxes1[..., 0], bboxes2[..., 0])).clamp_(min=0)
  43. h_intersect = (torch.min(bboxes1[..., 3], bboxes2[..., 3])
  44. - torch.max(bboxes1[..., 1], bboxes2[..., 1])).clamp_(min=0)
  45. area_intersect = w_intersect * h_intersect
  46. area_union = bboxes2_area + bboxes1_area - area_intersect
  47. ious = area_intersect / area_union.clamp(min=eps)
  48. if iou_type == "iou":
  49. return ious
  50. elif iou_type == "giou":
  51. g_w_intersect = torch.max(bboxes1[..., 2], bboxes2[..., 2]) \
  52. - torch.min(bboxes1[..., 0], bboxes2[..., 0])
  53. g_h_intersect = torch.max(bboxes1[..., 3], bboxes2[..., 3]) \
  54. - torch.min(bboxes1[..., 1], bboxes2[..., 1])
  55. ac_uion = g_w_intersect * g_h_intersect
  56. gious = ious - (ac_uion - area_union) / ac_uion.clamp(min=eps)
  57. return gious
  58. else:
  59. raise NotImplementedError
  60. # ------------------------- For distributed -------------------------
  61. def is_dist_avail_and_initialized():
  62. if not dist.is_available():
  63. return False
  64. if not dist.is_initialized():
  65. return False
  66. return True
  67. def get_world_size():
  68. if not is_dist_avail_and_initialized():
  69. return 1
  70. return dist.get_world_size()