dn_compoments.py 4.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import torch
  2. from .box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh
  3. def inverse_sigmoid(x, eps=1e-5):
  4. x = x.clamp(min=0., max=1.)
  5. return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
  6. def get_contrastive_denoising_training_group(targets,
  7. num_classes,
  8. num_queries,
  9. class_embed,
  10. num_denoising=100,
  11. label_noise_ratio=0.5,
  12. box_noise_scale=1.0,):
  13. if num_denoising <= 0:
  14. return None, None, None, None
  15. num_gts = [len(t['labels']) for t in targets]
  16. device = targets[0]['labels'].device
  17. max_gt_num = max(num_gts)
  18. if max_gt_num == 0:
  19. return None, None, None, None
  20. num_group = num_denoising // max_gt_num
  21. num_group = 1 if num_group == 0 else num_group
  22. # pad gt to max_num of a batch
  23. bs = len(num_gts)
  24. input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device)
  25. input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device)
  26. pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device)
  27. for i in range(bs):
  28. num_gt = num_gts[i]
  29. if num_gt > 0:
  30. input_query_class[i, :num_gt] = targets[i]['labels']
  31. input_query_bbox[i, :num_gt] = targets[i]['boxes']
  32. pad_gt_mask[i, :num_gt] = 1
  33. # each group has positive and negative queries.
  34. input_query_class = input_query_class.tile([1, 2 * num_group])
  35. input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
  36. pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
  37. # positive and negative mask
  38. negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device)
  39. negative_gt_mask[:, max_gt_num:] = 1
  40. negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
  41. positive_gt_mask = 1 - negative_gt_mask
  42. # contrastive denoising training positive index
  43. positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
  44. dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
  45. dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])
  46. # total denoising queries
  47. num_denoising = int(max_gt_num * 2 * num_group)
  48. if label_noise_ratio > 0:
  49. mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
  50. # randomly put a new one here
  51. new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
  52. input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
  53. if box_noise_scale > 0:
  54. known_bbox = box_cxcywh_to_xyxy(input_query_bbox)
  55. diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
  56. rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
  57. rand_part = torch.rand_like(input_query_bbox)
  58. rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
  59. rand_part *= rand_sign
  60. known_bbox += rand_part * diff
  61. known_bbox.clip_(min=0.0, max=1.0)
  62. input_query_bbox = box_xyxy_to_cxcywh(known_bbox)
  63. input_query_bbox = inverse_sigmoid(input_query_bbox)
  64. input_query_class = class_embed(input_query_class)
  65. tgt_size = num_denoising + num_queries
  66. # attn_mask = torch.ones([tgt_size, tgt_size], device=device) < 0
  67. attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device)
  68. # match query cannot see the reconstruction
  69. attn_mask[num_denoising:, :num_denoising] = True
  70. # reconstruct cannot see each other
  71. for i in range(num_group):
  72. if i == 0:
  73. attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
  74. if i == num_group - 1:
  75. attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * i * 2] = True
  76. else:
  77. attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
  78. attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * 2 * i] = True
  79. dn_meta = {
  80. "dn_positive_idx": dn_positive_idx,
  81. "dn_num_group": num_group,
  82. "dn_num_split": [num_denoising, num_queries]
  83. }
  84. return input_query_class, input_query_bbox, attn_mask, dn_meta