Browse Source

debug RT-DETR on gpu

yjh0410 1 year ago
parent
commit
cdd96c7b39

+ 1 - 1
dataset/data_augment/rtdetr_augment.py

@@ -364,7 +364,7 @@ class RTDetrBaseTransform(object):
         ])
 
 
-    def __call__(self, image, target, mosaic=False):
+    def __call__(self, image, target=None, mosaic=False):
         orig_h, orig_w = image.shape[:2]
         ratio = [self.img_size / orig_w, self.img_size / orig_h]
 

+ 9 - 9
models/detectors/rtdetr/basic_modules/dn_compoments.py

@@ -37,10 +37,10 @@ def get_contrastive_denoising_training_group(targets,
     # pad gt to max_num of a batch
     bs = len(targets)
     # [bs, max_gt_num]
-    input_query_class = torch.full([bs, max_gt_num], num_classes).long()
+    input_query_class = torch.full([bs, max_gt_num], num_classes, device=class_embed.device).long()
     # [bs, max_gt_num, 4]
-    input_query_bbox = torch.zeros([bs, max_gt_num, 4])
-    pad_gt_mask = torch.zeros([bs, max_gt_num])
+    input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=class_embed.device)
+    pad_gt_mask = torch.zeros([bs, max_gt_num], device=class_embed.device)
     for i in range(bs):
         num_gt = num_gts[i]
         if num_gt > 0:
@@ -54,7 +54,7 @@ def get_contrastive_denoising_training_group(targets,
     pad_gt_mask = pad_gt_mask.repeat(1, 2 * num_group)
 
     # positive and negative mask
-    negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1])
+    negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=class_embed.device)
     negative_gt_mask[:, max_gt_num:] = 1
     negative_gt_mask = negative_gt_mask.repeat(1, num_group, 1)
     positive_gt_mask = 1 - negative_gt_mask
@@ -71,11 +71,11 @@ def get_contrastive_denoising_training_group(targets,
         input_query_class = input_query_class.flatten()  # [bs * num_denoising]
         pad_gt_mask = pad_gt_mask.flatten()
         # half of bbox prob
-        mask = torch.rand(input_query_class.shape) < (label_noise_ratio * 0.5)
+        mask = torch.rand(input_query_class.shape, device=class_embed.device) < (label_noise_ratio * 0.5)
         chosen_idx = torch.nonzero(mask * pad_gt_mask).squeeze(-1)
         # randomly put a new one here
         new_label = torch.randint_like(
-            chosen_idx, 0, num_classes, dtype=input_query_class.dtype)
+            chosen_idx, 0, num_classes, dtype=input_query_class.dtype, device=class_embed.device)
         # [bs * num_denoising]
         input_query_class = torch.scatter(input_query_class, 0, chosen_idx, new_label)
         # input_query_class.scatter_(chosen_idx, new_label)
@@ -89,7 +89,7 @@ def get_contrastive_denoising_training_group(targets,
                            [1, 1, 2]) * box_noise_scale
 
         rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
-        rand_part = torch.rand(input_query_bbox.shape)
+        rand_part = torch.rand(input_query_bbox.shape, device=class_embed.device)
         rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (
             1 - negative_gt_mask)
         rand_part *= rand_sign
@@ -99,7 +99,7 @@ def get_contrastive_denoising_training_group(targets,
         input_query_bbox = inverse_sigmoid(input_query_bbox)
 
     # [num_classes + 1, hidden_dim]
-    class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]])])
+    class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
     # input_query_class = paddle.gather(class_embed, input_query_class.flatten(), axis=0)
 
     # input_query_class: [bs, num_denoising] -> [bs*num_denoising, hidden_dim]
@@ -108,7 +108,7 @@ def get_contrastive_denoising_training_group(targets,
     input_query_class = input_query_class.reshape(bs, num_denoising, -1)
     
     tgt_size = num_denoising + num_queries
-    attn_mask = torch.ones([tgt_size, tgt_size]) < 0
+    attn_mask = torch.ones([tgt_size, tgt_size], device=class_embed.device) < 0
     # match query cannot see the reconstruction
     attn_mask[num_denoising:, :num_denoising] = True
     # reconstruct cannot see each other

+ 2 - 2
models/detectors/rtdetr/basic_modules/transformer.py

@@ -405,8 +405,8 @@ class DeformableTransformerDecoderLayer(nn.Module):
         if attn_mask is not None:
             attn_mask = torch.where(
                 attn_mask.bool(),
-                torch.zeros(attn_mask.shape, dtype=tgt.dtype),
-                torch.full(attn_mask.shape, float("-inf"), dtype=tgt.dtype))
+                torch.zeros(attn_mask.shape, dtype=tgt.dtype, device=attn_mask.device),
+                torch.full(attn_mask.shape, float("-inf"), dtype=tgt.dtype, device=attn_mask.device))
         tgt2 = self.self_attn(q, k, value=tgt)[0]
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)

+ 7 - 7
models/detectors/rtdetr/loss.py

@@ -37,8 +37,8 @@ class Criterion(object):
     def __call__(self, dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta, targets=None):
         assert targets is not None
 
-        gt_labels = [t['labels'] for t in targets]  # (List[torch.Tensor]) -> List[[N,]]
-        gt_boxes  = [t['boxes']  for t in targets]  # (List[torch.Tensor]) -> List[[N, 4]]
+        gt_labels = [t['labels'].to(dec_out_bboxes.device) for t in targets]  # (List[torch.Tensor]) -> List[[N,]]
+        gt_boxes  = [t['boxes'].to(dec_out_bboxes.device)  for t in targets]  # (List[torch.Tensor]) -> List[[N, 4]]
 
         if dn_meta is not None:
             if isinstance(dn_meta, list):
@@ -147,7 +147,7 @@ class DETRLoss(nn.Module):
         # logits: [b, query, num_classes], gt_class: list[[n, 1]]
         name_class = "loss_class" + postfix
 
-        target_label = torch.full(logits.shape[:2], bg_index).long()
+        target_label = torch.full(logits.shape[:2], bg_index, device=logits.device).long()
         bs, num_query_objects = target_label.shape
         num_gt = sum(len(a) for a in gt_class)
         if num_gt > 0:
@@ -161,7 +161,7 @@ class DETRLoss(nn.Module):
         # one-hot label
         target_label = F.one_hot(target_label, self.num_classes + 1)[..., :-1].float()
         if iou_score is not None and self.use_vfl:
-            target_score = torch.zeros([bs, num_query_objects])
+            target_score = torch.zeros([bs, num_query_objects], device=logits.device)
             if num_gt > 0:
                 target_score = target_score.reshape(-1, 1)
                 target_score[index] = iou_score.float()
@@ -260,16 +260,16 @@ class DETRLoss(nn.Module):
         src_idx = torch.cat([src for (src, _) in match_indices])
         src_idx += (batch_idx * num_query_objects)
         target_assign = torch.cat([
-            torch.gather(t, 0, dst) for t, (_, dst) in zip(target, match_indices)
+            torch.gather(t, 0, dst.to(t.device)) for t, (_, dst) in zip(target, match_indices)
         ])
         return src_idx, target_assign
 
     def _get_src_target_assign(self, src, target, match_indices):
-        src_assign = torch.cat([t[I] if len(I) > 0 else torch.zeros([0, t.shape[-1]])
+        src_assign = torch.cat([t[I] if len(I) > 0 else torch.zeros([0, t.shape[-1]], device=src.device)
             for t, (I, _) in zip(src, match_indices)
         ])
 
-        target_assign = torch.cat([t[J] if len(J) > 0 else torch.zeros([0, t.shape[-1]])
+        target_assign = torch.cat([t[J] if len(J) > 0 else torch.zeros([0, t.shape[-1]], device=src.device)
             for t, (_, J) in zip(target, match_indices)
         ])
 

+ 7 - 1
models/detectors/rtdetr/rtdetr.py

@@ -101,7 +101,13 @@ class RT_DETR(nn.Module):
             # post-process
             bboxes, scores, labels = self.post_process(box_preds, cls_preds)
 
-            return bboxes, scores, labels
+            outputs = {
+                "scores": scores.cpu().numpy(),
+                "labels": labels.cpu().numpy(),
+                "bboxes": bboxes.cpu().numpy(),
+            }
+
+            return outputs
         
 
 if __name__ == '__main__':

+ 2 - 0
models/detectors/rtdetr/rtdetr_decoder.py

@@ -214,6 +214,8 @@ class RTDETRTransformer(nn.Module):
         bs, _, _ = memory.shape
         # prepare input for decoder
         anchors, valid_mask = self.generate_anchors(spatial_shapes)
+        anchors = anchors.to(memory.device)
+        valid_mask = valid_mask.to(memory.device)
         memory = torch.where(valid_mask, memory, torch.as_tensor(0.))
         output_memory = self.enc_output(memory)