dn_compoments.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import torch
  2. def inverse_sigmoid(x, eps=1e-5):
  3. x = x.clamp(min=0., max=1.)
  4. return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
  5. def bbox_cxcywh_to_xyxy(x):
  6. cxcy, wh = torch.split(x, 2, axis=-1)
  7. return torch.cat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], dim=-1)
  8. def bbox_xyxy_to_cxcywh(x):
  9. x1, y1, x2, y2 = x.split(4, axis=-1)
  10. return torch.cat(
  11. [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)], axis=-1)
  12. def get_contrastive_denoising_training_group(targets,
  13. num_classes,
  14. num_queries,
  15. class_embed,
  16. num_denoising=100,
  17. label_noise_ratio=0.5,
  18. box_noise_scale=1.0):
  19. if num_denoising <= 0:
  20. return None, None, None, None
  21. num_gts = [len(t) for t in targets["labels"]]
  22. max_gt_num = max(num_gts)
  23. if max_gt_num == 0:
  24. return None, None, None, None
  25. num_group = num_denoising // max_gt_num
  26. num_group = 1 if num_group == 0 else num_group
  27. # pad gt to max_num of a batch
  28. bs = len(targets["labels"])
  29. input_query_class = torch.full([bs, max_gt_num], num_classes).long()
  30. input_query_bbox = torch.zeros([bs, max_gt_num, 4])
  31. pad_gt_mask = torch.zeros([bs, max_gt_num])
  32. for i in range(bs):
  33. num_gt = num_gts[i]
  34. if num_gt > 0:
  35. input_query_class[i, :num_gt] = targets["labels"][i].squeeze(-1)
  36. input_query_bbox[i, :num_gt] = targets["boxes"][i]
  37. pad_gt_mask[i, :num_gt] = 1
  38. # each group has positive and negative queries.
  39. input_query_class = input_query_class.repeat(1, 2 * num_group)
  40. input_query_bbox = input_query_bbox.repeat(1, 2 * num_group, 1)
  41. pad_gt_mask = pad_gt_mask.repeat(1, 2 * num_group)
  42. # positive and negative mask
  43. negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1])
  44. negative_gt_mask[:, max_gt_num:] = 1
  45. negative_gt_mask = negative_gt_mask.repeat(1, num_group, 1)
  46. positive_gt_mask = 1 - negative_gt_mask
  47. # contrastive denoising training positive index
  48. positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
  49. dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
  50. dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])
  51. # total denoising queries
  52. num_denoising = int(max_gt_num * 2 * num_group)
  53. if label_noise_ratio > 0:
  54. input_query_class = input_query_class.flatten()
  55. pad_gt_mask = pad_gt_mask.flatten()
  56. # half of bbox prob
  57. mask = torch.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
  58. chosen_idx = torch.nonzero(mask * pad_gt_mask).squeeze(-1)
  59. # randomly put a new one here
  60. new_label = torch.randint_like(
  61. chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
  62. input_query_class.scatter_(chosen_idx, new_label)
  63. input_query_class = input_query_class.reshape(bs, num_denoising)
  64. pad_gt_mask = pad_gt_mask.reshape(bs, num_denoising)
  65. if box_noise_scale > 0:
  66. known_bbox = bbox_cxcywh_to_xyxy(input_query_bbox)
  67. diff = torch.tile(input_query_bbox[..., 2:] * 0.5,
  68. [1, 1, 2]) * box_noise_scale
  69. rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
  70. rand_part = torch.rand(input_query_bbox.shape)
  71. rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (
  72. 1 - negative_gt_mask)
  73. rand_part *= rand_sign
  74. known_bbox += rand_part * diff
  75. known_bbox.clip_(min=0.0, max=1.0)
  76. input_query_bbox = bbox_xyxy_to_cxcywh(known_bbox)
  77. input_query_bbox = inverse_sigmoid(input_query_bbox)
  78. class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]])])
  79. input_query_class = torch.gather(class_embed, 1, input_query_class.flatten())
  80. input_query_class = input_query_class.reshape(bs, num_denoising, -1)
  81. tgt_size = num_denoising + num_queries
  82. attn_mask = torch.ones([tgt_size, tgt_size]) < 0
  83. # match query cannot see the reconstruction
  84. attn_mask[num_denoising:, :num_denoising] = True
  85. # reconstruct cannot see each other
  86. for i in range(num_group):
  87. if i == 0:
  88. attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
  89. 2 * (i + 1):num_denoising] = True
  90. if i == num_group - 1:
  91. attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
  92. i * 2] = True
  93. else:
  94. attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), max_gt_num *
  95. 2 * (i + 1):num_denoising] = True
  96. attn_mask[max_gt_num * 2 * i:max_gt_num * 2 * (i + 1), :max_gt_num *
  97. 2 * i] = True
  98. attn_mask = ~attn_mask
  99. dn_meta = {
  100. "dn_positive_idx": dn_positive_idx,
  101. "dn_num_group": num_group,
  102. "dn_num_split": [num_denoising, num_queries]
  103. }
  104. return input_query_class, input_query_bbox, attn_mask, dn_meta