yjh0410 2 жил өмнө
parent
commit
181773c9d8

+ 1 - 40
models/detectors/rtdetr/image_encoder/img_encoder.py

@@ -20,38 +20,6 @@ class ImageEncoder(nn.Module):
         self.csfm = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=round(cfg['d_model']*cfg['width']))
 
 
-    def position_embedding(self, x, temperature=10000):
-        hs, ws = x.shape[-2:]
-        device = x.device
-        num_pos_feats = x.shape[1] // 2       
-        scale = 2 * 3.141592653589793
-
-        # generate xy coord mat
-        y_embed, x_embed = torch.meshgrid(
-            [torch.arange(1, hs+1, dtype=torch.float32),
-             torch.arange(1, ws+1, dtype=torch.float32)])
-        y_embed = y_embed / (hs + 1e-6) * scale
-        x_embed = x_embed / (ws + 1e-6) * scale
-    
-        # [H, W] -> [1, H, W]
-        y_embed = y_embed[None, :, :].to(device)
-        x_embed = x_embed[None, :, :].to(device)
-
-        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
-        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
-        dim_t = temperature ** (2 * dim_t_)
-
-        pos_x = torch.div(x_embed[:, :, :, None], dim_t)
-        pos_y = torch.div(y_embed[:, :, :, None], dim_t)
-        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
-        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
-
-        # [B, C, H, W]
-        pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
-        
-        return pos_embed
-        
-
     def forward(self, x):
         # Backbone
         pyramid_feats = self.backbone(x)
@@ -62,14 +30,7 @@ class ImageEncoder(nn.Module):
         # CSFM
         pyramid_feats = self.csfm(pyramid_feats)
 
-        # Prepare memory & memoery_pos for Decoder
-        memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)
-        memory = memory.permute(0, 2, 1).contiguous()
-        memory_pos = torch.cat([self.position_embedding(feat).flatten(2)
-                                for feat in pyramid_feats], dim=-1)
-        memory_pos = memory_pos.permute(0, 2, 1).contiguous()
-
-        return memory, memory_pos
+        return pyramid_feats
 
 
 # build img-encoder

+ 46 - 2
models/detectors/rtdetr/rtdetr.py

@@ -44,6 +44,38 @@ class RTDETR(nn.Module):
 
 
     # ---------------------- Basic Functions ----------------------
+    def position_embedding(self, x, temperature=10000):
+        hs, ws = x.shape[-2:]
+        device = x.device
+        num_pos_feats = x.shape[1] // 2       
+        scale = 2 * 3.141592653589793
+
+        # generate xy coord mat
+        y_embed, x_embed = torch.meshgrid(
+            [torch.arange(1, hs+1, dtype=torch.float32),
+             torch.arange(1, ws+1, dtype=torch.float32)])
+        y_embed = y_embed / (hs + 1e-6) * scale
+        x_embed = x_embed / (ws + 1e-6) * scale
+    
+        # [H, W] -> [1, H, W]
+        y_embed = y_embed[None, :, :].to(device)
+        x_embed = x_embed[None, :, :].to(device)
+
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
+        dim_t = temperature ** (2 * dim_t_)
+
+        pos_x = torch.div(x_embed[:, :, :, None], dim_t)
+        pos_y = torch.div(y_embed[:, :, :, None], dim_t)
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+
+        # [B, C, H, W]
+        pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        
+        return pos_embed
+        
+
     @torch.jit.unused
     def set_aux_loss(self, outputs_class, outputs_coord):
         # this is a workaround to make torchscript happy, as torchscript
@@ -57,7 +89,13 @@ class RTDETR(nn.Module):
     @torch.no_grad()
     def inference_single_image(self, x):
         # -------------------- Encoder --------------------
-        memory, memory_pos = self.encoder(x)
+        pyramid_feats = self.encoder(x)
+
+        # -------------------- Pos Embed --------------------
+        memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)
+        memory_pos = torch.cat([self.position_embedding(feat).flatten(2) for feat in pyramid_feats], dim=-1)
+        memory = memory.permute(0, 2, 1).contiguous()
+        memory_pos = memory_pos.permute(0, 2, 1).contiguous()
 
         # -------------------- Decoder --------------------
         hs, reference = self.decoder(memory, memory_pos)
@@ -93,8 +131,14 @@ class RTDETR(nn.Module):
             return self.inference_single_image(x)
         else:
             # -------------------- Encoder --------------------
-            memory, memory_pos = self.encoder(x)
+            pyramid_feats = self.encoder(x)
 
+            # -------------------- Pos Embed --------------------
+            memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)
+            memory_pos = torch.cat([self.position_embedding(feat).flatten(2) for feat in pyramid_feats], dim=-1)
+            memory = memory.permute(0, 2, 1).contiguous()
+            memory_pos = memory_pos.permute(0, 2, 1).contiguous()
+            
             # -------------------- Decoder --------------------
             hs, reference = self.decoder(memory, memory_pos)