@@ -37,10 +37,10 @@ def get_contrastive_denoising_training_group(targets,
     # pad gt to max_num of a batch
     bs = len(targets)
     # [bs, max_gt_num]
-    input_query_class = torch.full([bs, max_gt_num], num_classes).long()
+    input_query_class = torch.full([bs, max_gt_num], num_classes, device=class_embed.device).long()
     # [bs, max_gt_num, 4]
-    input_query_bbox = torch.zeros([bs, max_gt_num, 4])
-    pad_gt_mask = torch.zeros([bs, max_gt_num])
+    input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=class_embed.device)
+    pad_gt_mask = torch.zeros([bs, max_gt_num], device=class_embed.device)
     for i in range(bs):
         num_gt = num_gts[i]
         if num_gt > 0:
@@ -54,7 +54,7 @@ def get_contrastive_denoising_training_group(targets,
     pad_gt_mask = pad_gt_mask.repeat(1, 2 * num_group)

     # positive and negative mask
-    negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1])
+    negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=class_embed.device)
     negative_gt_mask[:, max_gt_num:] = 1
     negative_gt_mask = negative_gt_mask.repeat(1, num_group, 1)
     positive_gt_mask = 1 - negative_gt_mask
@@ -71,11 +71,11 @@ def get_contrastive_denoising_training_group(targets,
         input_query_class = input_query_class.flatten() # [bs * num_denoising]
         pad_gt_mask = pad_gt_mask.flatten()
         # half of bbox prob
-        mask = torch.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
+        mask = torch.rand(input_query_class.shape, device=class_embed.device) < (label_noise_ratio * 0.5)
         chosen_idx = torch.nonzero(mask * pad_gt_mask).squeeze(-1)
         # randomly put a new one here
         new_label = torch.randint_like(
-            chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
+            chosen_idx, 0, num_classes, dtype=input_query_class.dtype, device=class_embed.device)
         # [bs * num_denoising]
         input_query_class = torch.scatter(input_query_class, 0, chosen_idx, new_label)
         # input_query_class.scatter_(chosen_idx, new_label)
@@ -89,7 +89,7 @@ def get_contrastive_denoising_training_group(targets,
             [1, 1, 2]) * box_noise_scale

         rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
-        rand_part = torch.rand(input_query_bbox.shape)
+        rand_part = torch.rand(input_query_bbox.shape, device=class_embed.device)
         rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (
             1 - negative_gt_mask)
         rand_part *= rand_sign
@@ -99,7 +99,7 @@ def get_contrastive_denoising_training_group(targets,
     input_query_bbox = inverse_sigmoid(input_query_bbox)

     # [num_classes + 1, hidden_dim]
-    class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]])])
+    class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
     # input_query_class = paddle.gather(class_embed, input_query_class.flatten(), axis=0)

     # input_query_class: [bs, num_denoising] -> [bs*num_denoising, hidden_dim]
@@ -108,7 +108,7 @@ def get_contrastive_denoising_training_group(targets,
     input_query_class = input_query_class.reshape(bs, num_denoising, -1)

     tgt_size = num_denoising + num_queries
-    attn_mask = torch.ones([tgt_size, tgt_size]) < 0
+    attn_mask = torch.ones([tgt_size, tgt_size], device=class_embed.device) < 0
     # match query cannot see the reconstruction
     attn_mask[num_denoising:, :num_denoising] = True
     # reconstruct cannot see each other
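
Note: the sketch below is illustrative and not part of the patch. Assuming class_embed already lives on a CUDA device, it shows the device-mismatch failure that tensors created without an explicit device (which default to the CPU) would trigger, and why passing device=class_embed.device avoids it; the shapes are hypothetical.

import torch

# Hypothetical stand-in for the classification embedding weights on the GPU.
class_embed = torch.randn(80, 256, device="cuda")           # [num_classes, hidden_dim]

# Without an explicit device, freshly created tensors default to the CPU ...
pad_row_cpu = torch.zeros([1, class_embed.shape[-1]])
# ... so mixing them with CUDA tensors fails, e.g.
#     torch.cat([class_embed, pad_row_cpu])
# raises "RuntimeError: Expected all tensors to be on the same device ...".

# Creating the tensor on the same device as class_embed keeps every operand
# together, which is what the added device=class_embed.device arguments do.
pad_row = torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)
class_embed = torch.cat([class_embed, pad_row])             # [num_classes + 1, hidden_dim]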